def get_optim(cfg, model, dataset_iter_num):
    """Build an optimizer and (optionally) an LR scheduler from ``cfg.OPTIM``.

    Args:
        cfg: config object whose ``OPTIM`` section carries NAME, INIT_LR,
            ADAM_EPSILON / SGD_MOMENTUM, WARM_UP_EPOCH, MAX_EPOCH,
            USE_LR_SCHEDULER and (when enabled) LR_SCHEDULER_TYPE / EXPONENT.
        model: model whose trainable parameters are optimized.
        dataset_iter_num: iterations per epoch, used to convert epochs to steps.

    Returns:
        (optimizer, scheduler) — scheduler is None when USE_LR_SCHEDULER is off.
    """
    opt_cfg = cfg.OPTIM
    name = opt_cfg.NAME
    assert name in ["FusedLAMB", "AdamW", "Adam", "SGD"], "optimizer not allowed"
    # Only parameters that still require gradients are handed to the optimizer.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    if name == "SGD":
        optimizer = SGD(trainable, lr=opt_cfg.INIT_LR, momentum=opt_cfg.SGD_MOMENTUM)
    elif name == "Adam":
        optimizer = Adam(trainable, lr=opt_cfg.INIT_LR, eps=opt_cfg.ADAM_EPSILON)
    elif name == "AdamW":
        optimizer = AdamW(trainable, lr=opt_cfg.INIT_LR, eps=opt_cfg.ADAM_EPSILON)
    else:  # "FusedLAMB" — guaranteed by the assert above
        optimizer = FusedLAMB(trainable, lr=opt_cfg.INIT_LR, eps=opt_cfg.ADAM_EPSILON)

    warmup_step = int(opt_cfg.WARM_UP_EPOCH * dataset_iter_num)
    max_step = opt_cfg.MAX_EPOCH * dataset_iter_num
    scheduler = None
    if opt_cfg.USE_LR_SCHEDULER:
        sched_type = opt_cfg.LR_SCHEDULER_TYPE
        if sched_type == "get_exponent_schedule_with_warmup":
            scheduler = get_exponent_schedule_with_warmup(
                optimizer, warmup_step, exponent=opt_cfg.EXPONENT)
        else:
            # Every other schedule factory is looked up by name in this module.
            scheduler = globals()[sched_type](optimizer, warmup_step, max_step)
    return optimizer, scheduler
def __init__(self, args, params):
    # Initialize the base wrapper first, then bind apex's FusedLAMB as the
    # concrete optimizer.
    super().__init__(args)
    try:
        # apex is an optional dependency: import it lazily so the class can be
        # defined (and error messages stay friendly) when apex is absent.
        from apex.optimizers import FusedLAMB
        # self.optimizer_config is presumably supplied by the base class /
        # a property — TODO confirm against the enclosing class.
        self._optimizer = FusedLAMB(params, **self.optimizer_config)
    except ImportError:
        raise ImportError('Please install apex to use LAMB optimizer')
def configure_optimizers(self):
    """Build the optimizer plus a per-step LR scheduler (Lightning hook)."""
    name = self.optimizer
    if name == "adam":
        # "adam" actually selects AdamW (decoupled weight decay, here 0).
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=0.0)
    elif name == "lamb":
        optimizer = FusedLAMB(
            self.parameters(), lr=self.learning_rate, weight_decay=0.0)
    elif name == "gremlin":
        from ..optim import GremlinAdam
        param_groups = [{"params": self.parameters(), "gremlin": True}]
        optimizer = GremlinAdam(param_groups, lr=self.learning_rate)
    else:
        raise ValueError(f"Unrecognized optimizer {self.optimizer}")
    # Scheduler factory is looked up by name; "interval": "step" makes
    # Lightning step it every batch rather than every epoch.
    factory = lr_schedulers.get(self.lr_scheduler)
    lr_scheduler = factory(optimizer, self.warmup_steps, self.trainer.max_steps)
    return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]
def configure_optimizers(self):
    """Create the single training optimizer selected by ``self.optimizer``."""
    choice = self.optimizer
    if choice not in ("adam", "lamb"):
        raise ValueError(f"Unrecognized optimizer {self.optimizer}")
    # "adam" maps to decoupled-weight-decay AdamW; "lamb" to apex FusedLAMB.
    opt_cls = torch.optim.AdamW if choice == "adam" else FusedLAMB
    optimizer = opt_cls(
        self.parameters(),
        lr=self.learning_rate,
        weight_decay=self.l2_coeff,
    )
    return [optimizer]
def get_optimizer(optimizer_name: str, parameters, learning_rate: float, weight_decay=0.0,
                  **kwargs):
    """Create an optimizer by (case-insensitive) name.

    Args:
        optimizer_name: one of "sgd", "adam", "rms", "adamw", "radam",
            "ranger", "lamb", "fused_lamb", "fused_adam", "fused_sgd",
            "diffgrad", "novograd" (any casing).
        parameters: iterable of parameters / parameter groups to optimize.
        learning_rate: initial learning rate.
        weight_decay: weight decay (L2 / decoupled, per optimizer semantics).
        **kwargs: forwarded verbatim to the optimizer constructor.

    Returns:
        A configured optimizer instance.

    Raises:
        ValueError: if the name is not recognized.
    """
    # Normalize once instead of calling .lower() in every branch.
    name = optimizer_name.lower()
    if name == "sgd":
        return SGD(parameters, learning_rate, momentum=0.9,
                   weight_decay=weight_decay, **kwargs)
    if name == "adam":
        # eps=1e-5 as Jeremy suggests
        return Adam(parameters, learning_rate, weight_decay=weight_decay,
                    eps=1e-5, **kwargs)
    if name == "rms":
        return RMSprop(parameters, learning_rate, weight_decay=weight_decay,
                       **kwargs)
    if name == "adamw":
        return AdamW(parameters, learning_rate, weight_decay=weight_decay,
                     eps=1e-5, **kwargs)
    if name == "radam":
        # eps=1e-5 as Jeremy suggests
        return RAdam(parameters, learning_rate, weight_decay=weight_decay,
                     eps=1e-5, **kwargs)
    if name == "ranger":
        return Ranger(parameters, learning_rate, weight_decay=weight_decay,
                      **kwargs)
    if name == "lamb":
        return Lamb(parameters, learning_rate, weight_decay=weight_decay,
                    **kwargs)
    if name == "fused_lamb":
        # apex is optional; import lazily only for the fused variants.
        from apex.optimizers import FusedLAMB
        return FusedLAMB(parameters, learning_rate, weight_decay=weight_decay,
                         **kwargs)
    if name == "fused_adam":
        from apex.optimizers import FusedAdam
        return FusedAdam(parameters, learning_rate, eps=1e-5,
                         weight_decay=weight_decay, adam_w_mode=True, **kwargs)
    if name == "fused_sgd":
        from apex.optimizers import FusedSGD
        return FusedSGD(parameters, learning_rate, weight_decay=weight_decay,
                        momentum=0.9, **kwargs)
    if name == "diffgrad":
        return DiffGrad(parameters, learning_rate, eps=1e-5,
                        weight_decay=weight_decay, **kwargs)
    if name == "novograd":
        return Novograd(parameters, learning_rate, eps=1e-5,
                        weight_decay=weight_decay, **kwargs)
    raise ValueError("Unsupported optimizer name " + optimizer_name)
def configure_optimizers(self):
    """Prepare optimizer and schedule (linear warmup and decay)"""
    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]

    def wants_decay(param_name):
        # Biases and LayerNorm weights are exempt from weight decay.
        return not any(nd in param_name for nd in no_decay)

    decayed, undecayed = [], []
    for pname, param in model.named_parameters():
        (decayed if wants_decay(pname) else undecayed).append(param)
    optimizer_grouped_parameters = [
        {"params": decayed, "weight_decay": self.hparams.weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]

    hp = self.hparams
    if hp.lamb:
        optimizer = FusedLAMB(optimizer_grouped_parameters,
                              lr=hp.learning_rate,
                              eps=hp.adam_epsilon)
    elif hp.adafactor:
        # Fixed-LR Adafactor (no relative-step schedule, no param scaling).
        optimizer = Adafactor(optimizer_grouped_parameters,
                              lr=hp.learning_rate,
                              scale_parameter=False,
                              relative_step=False)
    else:
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=hp.learning_rate,
                              eps=hp.adam_epsilon)
    self.opt = optimizer
    scheduler = self.get_lr_scheduler()
    return [optimizer], [scheduler]
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Create an optimizer for ``model`` from argparse-style ``args``.

    Args:
        args: namespace providing at least ``opt``, ``lr``, ``weight_decay``
            and ``momentum``; optionally ``opt_eps``, ``opt_betas``,
            ``opt_args`` (extra kwargs dict).
        model: model whose parameters are optimized; may expose
            ``no_weight_decay()`` to exempt specific parameters.
        filter_bias_and_bn: when True and weight decay is enabled, split
            params into decay / no-decay groups via ``add_weight_decay``.

    Returns:
        The configured optimizer, wrapped in Lookahead for a
        ``lookahead_``-prefixed ``args.opt``.

    Raises:
        ValueError: if ``args.opt`` names an unknown optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        # Decay now lives inside the param groups, so zero the default.
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # Common kwargs; eps/betas are only forwarded when explicitly configured.
    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    # An optional prefix (e.g. "lookahead_sgd") selects a wrapper below.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum,
                              nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum,
                              nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        print(' ')
        print('Gradient centralization is enabled for AdamP optimizer.')
        print(' ')
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True,
                          use_gc=True, gc_conv_only=True, gc_loc=False,
                          **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True,
                         **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        # Adafactor can derive its own LR schedule when lr is falsy.
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9,
                                  momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9,
                              momentum=args.momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum,
                             nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum,
                             nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # BUGFIX: was `assert False and "Invalid optimizer"` — the message was
        # never attached (False short-circuits) and the whole assert vanishes
        # under `python -O` — followed by a bare `raise ValueError`. Raise one
        # informative error instead.
        raise ValueError(f'Invalid optimizer: {args.opt}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)
    return optimizer
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Create an optimizer for ``model`` from ``args`` (older factory variant).

    Args:
        args: namespace with ``opt``, ``lr``, ``weight_decay``, ``momentum``
            and ``opt_eps``.
        model: model to optimize.
        filter_bias_and_bn: when True and weight decay is active, use
            ``add_weight_decay`` to exempt bias/norm params from decay;
            otherwise only unfrozen params (via ``unfrozen_params``) are used.

    Returns:
        Configured optimizer, wrapped in Lookahead for a ``lookahead_`` prefix.

    Raises:
        ValueError: for an unknown ``args.opt``.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        print("has weight decay and filter bias")
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        print("Comes here to unfrozen params inside optim")
        parameters = unfrozen_params(model)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, lr=args.lr,
                               weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters, lr=args.lr,
                          weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, lr=args.lr,
                          weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, lr=args.lr,
                          weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, lr=args.lr,
                                   weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, lr=args.lr, alpha=0.9,
                                  eps=args.opt_eps, momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, lr=args.lr, alpha=0.9,
                              eps=args.opt_eps, momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, lr=args.lr,
                             weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, lr=args.lr,
                               weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'fusedmomentum':
        print("my optimizer")
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=False,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=True,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, lr=args.lr,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters, lr=args.lr, betas=(0.95, 0.98),
                                  weight_decay=weight_decay, eps=args.opt_eps)
    else:
        # BUGFIX: was `assert False and "Invalid optimizer"` (message never
        # attached; the assert is stripped under `python -O`) followed by a
        # bare `raise ValueError`. Raise one informative error instead.
        raise ValueError(f'Invalid optimizer: {args.opt}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)
    return optimizer
def train(args, teacher_args):
    """Train FCL-taco2 model.

    Builds the student model (from ``args``) and the teacher model (from
    ``teacher_args``), restores the teacher from its AMP checkpoint and
    freezes it, then runs chainer-style training of the student with
    evaluation, snapshots, plots and early stopping.
    """
    set_deterministic_pytorch(args)
    # args.use_fe_condition = True
    # # pre-occupy GPU
    # buff = torch.randn(int(1e9)).cuda()
    # del buff
    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    # reverse input and output dimension (TTS: text side is the network input)
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims: " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension (speaker embedding / second target)
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to" + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args, args, teacher_args=teacher_args)
    #print('\n\nteacher_args:', teacher_args.embed_dim, '\n\n')
    teacher_model_class = dynamic_import(teacher_args.model_module)
    teacher_model = teacher_model_class(idim, odim, teacher_args, teacher_args)
    #teacher_model = teacher_model.to('cuda')

    # the teacher must come with a trained (apex-AMP) checkpoint
    if teacher_args.amp_checkpoint is None:
        raise ValueError('please provide the teacher-model-amp-checkpoint')
    else:
        logging.info("teacher-model resumed from %s" % teacher_args.amp_checkpoint)
        teacher_checkpoint = torch.load(teacher_args.amp_checkpoint)
        teacher_model.load_state_dict(teacher_checkpoint['model'])

    # print('tts_wds:', model.base_plot_keys)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        # model = torch.nn.DataParallel(model, device_ids=[4,5,6,7])
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)
    teacher_model = teacher_model.to(device)
    for param in teacher_model.parameters():  # fix teacher model params
        param.requires_grad = False

    # freeze modules, if specified
    if args.freeze_mods:
        if hasattr(model, "module"):
            # DataParallel prefixes parameter names with "module."
            freeze_mods = ["module." + x for x in args.freeze_mods]
        else:
            freeze_mods = args.freeze_mods
        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False
        model_params = filter(lambda x: x.requires_grad, model.parameters())
    else:
        model_params = model.parameters()

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == 'lamb':
        # NOTE(review): lr/betas/eps/weight_decay are hard-coded here (args.lr
        # is ignored) and model.parameters() is passed instead of model_params,
        # so frozen modules would still reach FusedLAMB — confirm intended.
        kw = dict(lr=0.1, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-6)
        from apex.optimizers import FusedAdam, FusedLAMB
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    if args.use_amp:
        # apex AMP, O1 = mixed precision with dynamic loss scaling
        opt_level = 'O1'
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        if args.amp_checkpoint is not None:
            logging.info("resumed from %s" % args.amp_checkpoint)
            checkpoint = torch.load(args.amp_checkpoint)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            amp.load_state_dict(checkpoint['amp'])

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    num_batches = len(train_json.keys()) // args.batch_size

    # sortagrad: present shortest batches first for the initial epoch(s)
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    print(f'\n\n batch_sort_key: {args.batch_sort_key} \n\n')

    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    from io_utils_fcl import LoadInputsAndTargets

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )
    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    converter = CustomConverter(
        reduction_factor=args.reduction_factor,
        use_fe_condition=args.use_fe_condition,
        append_position=args.append_position,
    )

    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        teacher_model, model, args.grad_clip, train_iter, optimizer, device,
        args.accum_grad, args.use_amp, num_batches, args.outdir
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(teacher_model, model, valid_iter, reporter, device),
        trigger=eval_interval
    )

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    # if args.tensorboard_dir is not None and args.tensorboard_dir != "":
    #     writer = SummaryWriter(args.tensorboard_dir)
    #     trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params (nn.Module): model containing parameters to optimize
        opt: name of optimizer to create
        lr: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
        layer_decay: layer-wise LR decay factor (param groups built per layer)
        param_group_fn: custom callable producing param groups from the model
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        ValueError: if `opt` names an unknown optimizer.
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay)
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    # An optional prefix (e.g. "lookahead_sgd") selects a wrapper at the end.
    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        try:
            # NOTE PyTorch >= 1.10 should have native NAdam
            # BUGFIX: was `optim.Nadam`, which does not exist (the native class
            # is `NAdam`), so this fast path always raised AttributeError and
            # the local fallback below was unconditionally used.
            optimizer = optim.NAdam(parameters, **opt_args)
        except AttributeError:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        # BUGFIX: was `assert False and "Invalid optimizer"` — the message was
        # never attached and the assert disappears under `python -O` — plus a
        # bare `raise ValueError`. Raise one informative error instead.
        raise ValueError(f'Invalid optimizer: {opt}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def main():
    """Entry point for PyTorch FastPitch training.

    Parses CLI arguments (two-pass: base + model-specific), builds the model
    and a FusedAdam/FusedLAMB optimizer, optionally resumes from a
    checkpoint, then runs the epoch/iteration loop with gradient
    accumulation, native AMP, EMA and periodic validation/checkpointing.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    distributed_run = args.world_size > 1

    # seed every rank differently but deterministically
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')
    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    # second parse pass picks up FastPitch-specific arguments
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        # NOTE(review): bare ValueError carries no message — consider naming
        # the unrecognized optimizer when next touching this code.
        raise ValueError

    # native AMP grad scaler (replaces the commented-out apex path below)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    #if args.amp:
        #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    # one-element lists so load_checkpoint can update them in place
    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args.local_rank, model, ema_model, optimizer,
                        start_epoch, start_iter, model_config, args.amp,
                        ch_fpath, args.world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset, num_workers=16, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()
    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()
        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:
            if accumulated_steps == 0:
                # start of a fresh "virtual" iteration (accumulation window)
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.perf_counter()

                adjust_learning_rate(total_iter, optimizer,
                                     args.learning_rate, args.warmup_steps)
                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)

            #AMP upstream autocast
            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x, use_gt_durations=True)
                loss, meta = criterion(y_pred, y)
                # scale so the accumulated gradient matches a full batch
                loss /= args.gradient_accumulation_steps

                meta = {k: v / args.gradient_accumulation_steps
                        for k, v in meta.items()}

            if args.amp:
                #with amp.scale_loss(loss, optimizer) as scaled_loss:
                    #scaled_loss.backward()
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size)
                        for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.gradient_accumulation_steps == 0:
                # accumulation window complete: clip, step and log
                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    # unscale before clipping so the threshold applies to
                    # true gradient magnitudes
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                    #optimizer.zero_grad(set_to_none=True)
                    optimizer.zero_grad()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    optimizer.step()
                apply_ema_decay(model, ema_model, args.ema_decay)

                iter_time = time.perf_counter() - iter_start_time
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                logger.log(
                    (epoch, epoch_iter, num_iters),
                    tb_total_steps=total_iter,
                    subset='train',
                    data=OrderedDict([
                        ('loss', iter_loss),
                        ('mel_loss', iter_mel_loss),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])
                    ]),
                )

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_time = time.perf_counter() - epoch_start_time
        logger.log(
            (epoch, ),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss / epoch_iter),
                              ('mel_loss', epoch_mel_loss / epoch_iter),
                              ('frames/s', epoch_num_frames / epoch_time),
                              ('took', epoch_time)]),
        )

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu,
                 use_gt_durations=True)

        if args.ema_decay > 0:
            validate(ema_model, epoch, total_iter, criterion, valset,
                     args.batch_size, collate_fn, distributed_run,
                     batch_to_gpu, use_gt_durations=True, ema=True)

        if (epoch > 0 and args.epochs_per_checkpoint > 0
                and (epoch % args.epochs_per_checkpoint == 0)
                and args.local_rank == 0):
            checkpoint_path = os.path.join(
                args.output, f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(args.local_rank, model, ema_model, optimizer,
                            scaler, epoch, total_iter, model_config, args.amp,
                            checkpoint_path)
        logger.flush()

    # Finished training
    logger.log(
        (),
        tb_total_steps=None,
        subset='train_avg',
        data=OrderedDict([('loss', epoch_loss / epoch_iter),
                          ('mel_loss', epoch_mel_loss / epoch_iter),
                          ('frames/s', epoch_num_frames / epoch_time),
                          ('took', epoch_time)]),
    )
    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)

    # final checkpoint only if the last epoch wasn't already checkpointed
    if (epoch > 0 and args.epochs_per_checkpoint > 0
            and (epoch % args.epochs_per_checkpoint != 0)
            and args.local_rank == 0):
        checkpoint_path = os.path.join(
            args.output, f"FastPitch_checkpoint_{epoch}.pt")
        save_checkpoint(args.local_rank, model, ema_model, optimizer, scaler,
                        epoch, total_iter, model_config, args.amp,
                        checkpoint_path)
def run(config, args):
    """Set up and run WaveGrad training for one process.

    Resolves the distributed rank/world size (env vars take precedence over
    ``args``), seeds RNGs per-rank, builds the model, optimizer and optional
    AMP/DDP wrappers, optionally restores a checkpoint, and periodically saves
    checkpoints.

    Args:
        config: model/config object passed to ``WaveGrad`` and checkpoints.
        args: parsed command-line namespace (seed, optimizer, amp, paths...).
    """
    # Prefer torchrun-style env vars; fall back to explicit CLI values.
    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    # Per-rank seeding so each worker draws different data/noise.
    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)
    # if local_rank == 0:
    #     if not os.path.exists(args.output):
    #         os.makedirs(args.output)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    if distributed_run:
        init_distributed(args, world_size, local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    if local_rank == 0:
        print("start training")
        print("args", args)
        print("config", config)
    #############################################
    # model
    if local_rank == 0:
        print("load model")
    # NOTE(review): model is moved with .cuda() regardless of `device` above —
    # confirm CPU-only runs are actually supported.
    model = WaveGrad(config).cuda()
    my_schedule = model.set_new_noise_schedule
    compute_loss = model.compute_loss
    # if local_rank == 0:
    #     print(model)
    # optimizer amp config
    if local_rank == 0:
        print("configure optimizer and amp")
    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    elif args.optimizer == 'pytorch':
        optimizer = torch.optim.Adam(model.parameters(), **kw)
    else:
        raise ValueError
    if args.amp:
        # apex AMP O1: patched casts, dynamic loss scaling.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)
    # Mutable one-element lists so load_checkpoint can update them in place.
    start_epoch = [1]
    start_iter = [0]
    ################
    # load checkpoint
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
        load_checkpoint(local_rank, model, optimizer, start_epoch, start_iter,
                        config, args.amp, ch_fpath, world_size)
    # NOTE(review): `epoch` is not defined anywhere in the visible code — the
    # training loop appears to be missing between the checkpoint load and this
    # point; as written this line would raise NameError. Also the iteration
    # formula (epoch * rank * batch_size) looks suspicious — verify upstream.
    iteration = epoch * args.rank * args.batch_size
    if local_rank == 0:
        if (epoch % args.epochs_per_checkpoint == 0):
            # Full checkpoint (model + optimizer) and a model-only snapshot.
            ch_path = os.path.join(args.output,
                                   "WaveGrad_ch_{:d}.pt".format(epoch))
            save_checkpoint(local_rank, model, optimizer, epoch, iteration,
                            config, args.amp, ch_path)
            ch_path = os.path.join(args.output,
                                   "WaveGrad_model_ch_{:d}.pt".format(epoch))
            save_checkpoint_modelonly(local_rank, model, epoch, iteration,
                                      config, ch_path)
def create_optimizer_v2(model: nn.Module,
                        optimizer_name: str = 'sgd',
                        learning_rate: Optional[float] = None,
                        weight_decay: float = 0.,
                        momentum: float = 0.9,
                        filter_bias_and_bn: bool = True,
                        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for
    optimization. For more general use an interface that allows selection of
    parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a
        weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model (nn.Module): model containing parameters to optimize
        optimizer_name: name of optimizer to create (a ``fused*`` prefix
            selects the apex fused variant; a ``lookahead_`` prefix wraps the
            result in Lookahead)
        learning_rate: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas
            via kwargs)
        filter_bias_and_bn: filter out bias, bn and other 1d params from
            weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        ValueError: if ``optimizer_name`` is not a recognized optimizer.
    """
    opt_lower = optimizer_name.lower()
    if weight_decay and filter_bias_and_bn:
        # Put 1d params (bias/norm) into a no-decay group; decay is then
        # carried per-group, so zero it out at the optimizer level.
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), \
            'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs)
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # SGD takes no eps; drop it if it leaked in through kwargs.
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False,
                              **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False,
                              print_change_log=False, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True,
                         **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        # Adafactor can derive its own lr schedule when lr is None.
        if not learning_rate:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum,
                                  **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum,
                              **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # Was `assert False and "Invalid optimizer"` (message never shown,
        # stripped under -O) followed by an unreachable bare raise.
        raise ValueError(f'Invalid optimizer {optimizer_name!r}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, FusedLAMB optimizer and LR schedule.

    Optionally resumes model/optimizer/AMP state from a checkpoint and wraps
    the model for distributed (DDP) or multi-GPU (DataParallel) execution.

    Args:
        args: parsed command-line namespace (checkpoint/fp16/distributed
            flags, learning rate, warmup, etc.).
        device: device the model is moved to before optimizer construction.

    Returns:
        Tuple ``(model, optimizer, lr_scheduler, checkpoint, global_step)``;
        ``checkpoint`` is None when not resuming.
    """
    # Prepare model
    config = BertConfig.from_json_file(args.config_file)

    # Pad vocab size up to a multiple of 8 (tensor-core friendly shapes).
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    model = BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        # Auto-detect the newest "ckpt_<step>.pt" when no explicit step given.
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2:
            # Phase-2 step counting starts after the end of phase 1.
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)

    # No weight decay for biases and LayerNorm/gamma/beta parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale)
        amp._amp_state.loss_scalers[0]._loss_scale = 2**20

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            # Fix: iterate the param-group dicts directly instead of
            # `for iter, item in enumerate(...)`, which shadowed the builtin
            # `iter`, never used `item`, and re-indexed the list each time.
            # Same dict objects are mutated, so behavior is unchanged.
            for group in checkpoint['optimizer']['param_groups']:
                group['step'] = global_step
                group['t_total'] = args.max_steps
                group['warmup'] = args.warmup_proportion
                group['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            # apex DDP with pre-division to keep all-reduce numerics stable.
            model = DDP(
                model,
                message_size=250000000,
                gradient_predivide_factor=torch.distributed.get_world_size())
        else:
            # Gradients are all-reduced manually after accumulation; just
            # broadcast initial weights from rank 0.
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0, ))
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    return model, optimizer, lr_scheduler, checkpoint, global_step
def create_optimizer_param(args, parameters):
    """Create an optimizer over an explicit ``parameters`` iterable.

    Args:
        args: namespace with ``opt`` (optimizer name; a ``fused*`` prefix
            selects an apex fused variant, a ``lookahead_`` prefix wraps the
            result in Lookahead), ``lr``, ``weight_decay``, ``momentum`` and
            optional ``opt_eps`` / ``opt_betas`` / ``opt_args`` overrides.
        parameters: parameters or parameter groups to optimize.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``args.opt`` is not a recognized optimizer.
    """
    opt_lower = args.opt.lower()
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # Base kwargs, extended by optional per-run overrides when present.
    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # SGD takes no eps; drop it if a global opt_eps was configured.
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum,
                              nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum,
                              nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True,
                         **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        # Adafactor can derive its own lr schedule when lr is None.
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9,
                                  momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum,
                              **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum,
                             nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # Was `assert False and "Invalid optimizer"` (message never shown,
        # stripped under -O) followed by an unreachable bare raise.
        raise ValueError(f'Invalid optimizer {args.opt!r}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def train(model: nn.Module, loss_fn: _Loss, train_dataloader: DataLoader,
          val_dataloader: DataLoader, callbacks: List[BaseCallback],
          logger: Logger, args):
    """Full training loop: optimizer setup, epoch iteration, periodic
    validation/checkpointing, and callback dispatch.

    Args:
        model: network to train; wrapped in DDP when a process group exists.
        loss_fn: training loss (consumed inside ``train_epoch``).
        train_dataloader / val_dataloader: data sources for train/eval.
        callbacks: receive fit/epoch/validation lifecycle events.
        logger: metric sink (``log_metrics``).
        args: namespace with optimizer choice, lr, momentum, weight_decay,
            amp, epochs, checkpoint paths/intervals, benchmark flag, etc.
    """
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank],
                                        output_device=local_rank)
        # Graph is static across iterations; lets DDP optimize communication.
        model._set_static_graph()

    model.train()
    # Scaler is a no-op when amp is disabled.
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate,
                              betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate,
                              betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        # Any other value falls through to plain SGD.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # Resume returns the epoch to restart from; 0 when starting fresh.
    epoch_start = load_state(model, optimizer, args.load_ckpt_path,
                             callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        # Re-seed the distributed sampler so shards reshuffle each epoch.
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx,
                           grad_scaler, optimizer, local_rank, callbacks, args)
        if dist.is_initialized():
            # Average the scalar epoch loss across ranks for logging.
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        # Periodic checkpointing (disabled in benchmark mode).
        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path,
                       callbacks)

        # Periodic evaluation, plus a final one on the last epoch.
        if not args.benchmark and (
            (args.eval_interval > 0 and
             (epoch_idx + 1) % args.eval_interval == 0)
                or epoch_idx + 1 == args.epochs):
            evaluate(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    # Final checkpoint after the loop completes.
    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path,
                   callbacks)

    for callback in callbacks:
        callback.on_fit_end()
def main():
    """QuartzNet ASR training entry point.

    Sets up (optionally distributed) training, builds data pipelines (DALI or
    native PyTorch), the model, optimizer and LR policy, then runs the
    epoch/step loop with AMP, gradient accumulation, optional EMA weights,
    periodic evaluation and checkpointing.
    """
    args = parse_args()

    assert (torch.cuda.is_available())
    # Predictions are printed on a subset of logged steps, so the two
    # frequencies must nest evenly.
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Per-rank seeding so workers see different augmentation/noise.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation >= 1
    batch_size = args.gpu_batch_size

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(
            gpu_id=args.local_rank,
            dataset_path=args.dataset_dir,
            config_data=train_dataset_kw,
            config_features=train_features_kw,
            json_names=args.train_manifests,
            batch_size=batch_size,
            grad_accumulation_steps=args.grad_accumulation,
            pipeline_type="train",
            device_type=args.dali_device,
            symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        # Native PyTorch pipeline: dataset + dataloader + filterbank features.
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir, args.train_manifests,
                                     symbols, **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir, args.val_manifests,
                                   symbols, **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    # One optimizer step per `grad_accumulation` batches.
    steps_per_epoch = len(train_loader) // args.grad_accumulation

    # set up the model
    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    elif args.optimizer == 'lamb98':
        optimizer = FusedLAMB(model.parameters(),
                              betas=(0.9, 0.98),
                              eps=1e-9,
                              **kw)
    elif args.optimizer == 'fused_novograd':
        optimizer = FusedNovoGrad(model.parameters(),
                                  betas=(0.95, 0),
                                  bias_correction=False,
                                  reg_inside_moment=True,
                                  grad_averaging=False,
                                  **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # LR policy closed over the static schedule hyperparameters.
    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step,
        epoch,
        args.lr,
        optimizer,
        steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs,
        hold_epochs=args.hold_epochs,
        num_epochs=args.epochs,
        policy=args.lr_policy,
        min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'QuartzNet',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # training loop
    model.train()

    if args.ema > 0.0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)
    # ema_model_weight_list, model_weight_list, overflow_buf_for_ema = ema_

    # pre-allocate
    # Run dummy forward/backward passes over the range of padded shapes to
    # warm up CUDA memory and avoid fragmentation during real training.
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(
                f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1,
                                size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                loss.backward()
                model.zero_grad()
        torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:
            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of
            # gradients: skip DDP all-reduce on non-final accumulation steps.
            if (multi_gpu
                    and accumulated_batches + 1 < args.grad_accumulation):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)
                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    print_once(f'WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1
                    scaler.scale(loss).backward()

            if accumulated_batches % args.grad_accumulation == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()
                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens,
                                                    symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f' Decoded: {pred_utt[:90]}')
                        print_once(f' Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log(
                        (epoch, step % steps_per_epoch or steps_per_epoch,
                         steps_per_epoch), step, 'train', {
                             'loss': step_loss,
                             'wer': 100.0 * wer,
                             'throughput': step_utts / step_time,
                             'took': step_time,
                             'lrate': optimizer.param_groups[0]['lr']
                         })

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        checkpointer.save(model,
                                          ema_model,
                                          optimizer,
                                          scaler,
                                          epoch,
                                          step,
                                          best_wer,
                                          is_best=True)
                        best_wer = wer
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log((epoch, ), None, 'train_avg', {
            'throughput': epoch_utts / epoch_time,
            'took': epoch_time,
            'loss': epoch_loss
        })
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # Final evaluation + checkpoint when training ran to completion.
    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

    checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                      best_wer)
    flush_log()
# Try the best-epoch checkpoint first (glob pattern), then the latest.
loaded = load_checkpoint('model_weights_best_epoch*pt')
if not loaded:
    # if best doesn't exist, take the latest
    loaded = load_checkpoint('model_weights_epoch*pt')

# Build the model and optionally restore weights from an explicit snapshot.
model = Bruno(config)
if config.from_snapshot is not None:
    state_dicts = torch.load(config.from_snapshot)
    model.load_state_dict(state_dicts['model'])
    logger.info(f'Model ckpt {config.from_snapshot} loaded.')
model = model.to(device)

# opt = torch.optim.Adam(model.parameters(), lr=config.lr)
opt = FusedLAMB(model.parameters(), lr=config.lr)
# NOTE(review): the snapshot is loaded from disk a second time here; the
# `state_dicts` from above could be reused — verify no intentional reload.
if config.from_snapshot is not None:
    state_dicts = torch.load(config.from_snapshot)
    opt.load_state_dict(state_dicts['opt'])

# LR schedule: exponential decay by default, optional cyclic policy.
if config.lr_policy == 'exp' or config.lr_policy is None:
    lr = torch.optim.lr_scheduler.ExponentialLR(opt, config.lr_decay)
elif config.lr_policy == 'cyclic':
    lr = torch.optim.lr_scheduler.CyclicLR(
        opt,
        0,
        config.lr,
        step_size_up=steps_per_epoch * 2,
        scale_fn=partial(scale_fn, decay=config.lr_decay),
        # momentum cycling is meaningless for LAMB-style optimizers
        cycle_momentum=False)
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Create an optimizer for ``model`` from command-line options.

    Args:
        args: namespace with ``opt`` (optimizer name; a ``fused*`` prefix
            selects an apex fused variant, a ``lookahead_`` prefix wraps the
            result in Lookahead), ``lr``, ``weight_decay``, ``momentum`` and
            optional ``opt_eps`` / ``opt_betas`` overrides.
        model: module whose parameters are optimized.
        filter_bias_and_bn: when weight decay is active, exclude bias/norm
            (1d) parameters from decay via ``add_weight_decay``.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``args.opt`` is not a recognized optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            # Fix: this must be a method CALL; the original assigned the bound
            # method itself (`model.no_weight_decay`), so no names were ever
            # skipped (compare create_optimizer_v2 above).
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        # Decay now lives in the per-group dicts; zero it at optimizer level.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    # Fix: this dict was built twice in a row; one construction suffices.
    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    # eps only applies to adaptive optimizers; exclude the SGD family.
    if hasattr(args,
               'opt_eps') and args.opt_eps is not None and opt_lower not in [
                   'sgd', 'momentum', 'fusedmomentum', 'fusedsgd'
               ]:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters, momentum=args.momentum,
                              nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters, momentum=args.momentum,
                              nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters, momentum=args.momentum,
                             nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # Was `assert False and "Invalid optimizer"` (message never shown,
        # stripped under -O) followed by an unreachable bare raise.
        raise ValueError(f'Invalid optimizer {args.opt!r}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def main():
    """Entry point for FastPitch training.

    Parses CLI arguments, builds the model, FusedAdam/FusedLAMB optimizer and
    data pipeline, then runs the epoch loop with gradient accumulation,
    optional AMP, optional EMA weights, periodic validation and checkpointing.
    Supports single-process and torch.distributed multi-GPU runs.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    # Rank/world size come from the launcher environment when present,
    # otherwise fall back to the CLI flags (single-node/manual launch).
    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    # Per-rank seed so data shuffling/dropout differ across workers.
    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    # Only rank 0 writes real logs; other ranks get a dummy logger.
    if local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)
        log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
        log_fpath = unique_dllogger_fpath(log_fpath)
        init_dllogger(log_fpath)
    else:
        init_dllogger(dummy=True)
    # NOTE(review): side-effect list comprehension — a plain for-loop would be
    # clearer; kept as-is.
    [DLLogger.log("PARAMETER", {k: v}) for k, v in vars(args).items()]

    # Re-parse after adding model-specific options; leftover args are errors.
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    # NOTE(review): `fpath` is computed but never used — dead code? confirm.
    fpath = common.utils.stats_filename(args.dataset_path,
                                        args.training_files, 'pitch_char')
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # EMA copy must exist before DDP wrapping / checkpoint restore.
    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    # Lists so load_checkpoint can mutate them in place.
    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
                        start_iter, model_config, args.amp, ch_fpath,
                        world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    # DistributedSampler already shuffles per-epoch; don't double-shuffle.
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset, num_workers=16, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()
    train_tblogger = TBLogger(local_rank, args.output, 'train')
    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
    if args.ema_decay > 0:
        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')

    val_loss = 0.0
    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.time()
        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        # Per-optimizer-step accumulators (reset after each optimizer.step()).
        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:
            # Start of a fresh accumulation window: bookkeeping + LR update.
            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.time()
                # NOTE(review): `start` is never read afterwards — confirm
                # whether it can be removed.
                start = time.perf_counter()

                old_lr = optimizer.param_groups[0]['lr']
                adjust_learning_rate(total_iter, optimizer,
                                     args.learning_rate, args.warmup_steps)
                new_lr = optimizer.param_groups[0]['lr']

                if new_lr != old_lr:
                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
                else:
                    dllog_lrate_change = None

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

            # Scale by accumulation steps so gradients sum to the full-batch
            # equivalent.
            loss /= args.gradient_accumulation_steps
            meta = {k: v / args.gradient_accumulation_steps
                    for k, v in meta.items()}

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, world_size)
                        for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0)
                         for k in meta}

            # Full accumulation window collected: clip, step, log, reset.
            if accumulated_steps % args.gradient_accumulation_steps == 0:

                train_tblogger.log_grads(total_iter, model)
                if args.amp:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)

                optimizer.step()
                apply_ema_decay(model, ema_model, args.ema_decay)

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                frames_per_sec = iter_num_frames / iter_time
                epoch_frames_per_sec += frames_per_sec
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_mel_loss += iter_mel_loss

                DLLogger.log(
                    (epoch, epoch_iter, num_iters),
                    OrderedDict([('train_loss', iter_loss),
                                 ('train_mel_loss', iter_mel_loss),
                                 ('train_frames/s', frames_per_sec),
                                 ('took', iter_time),
                                 ('lrate_change', dllog_lrate_change)]))
                train_tblogger.log_meta(total_iter, iter_meta)

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log((epoch, ), data=OrderedDict([
            ('avg_train_loss', epoch_loss / epoch_iter),
            ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
            ('avg_train_frames/s', epoch_num_frames / epoch_time),
            ('took', epoch_time)
        ]))

        tik = time.time()
        val_loss, meta, num_frames = validate(model, criterion, valset,
                                              args.batch_size, world_size,
                                              collate_fn, distributed_run,
                                              local_rank, batch_to_gpu,
                                              use_gt_durations=True)
        tok = time.time()

        DLLogger.log((epoch, ), data=OrderedDict([
            ('val_loss', val_loss),
            ('val_mel_loss', meta['mel_loss'].item()),
            ('val_frames/s', num_frames / (tok - tik)),
            ('took', tok - tik),
        ]))
        val_tblogger.log_meta(total_iter, meta)

        # Validate the EMA weights separately when EMA is enabled.
        if args.ema_decay > 0:
            tik_e = time.time()
            val_loss_e, meta_e, num_frames_e = validate(ema_model, criterion,
                                                        valset,
                                                        args.batch_size,
                                                        world_size, collate_fn,
                                                        distributed_run,
                                                        local_rank,
                                                        batch_to_gpu,
                                                        use_gt_durations=True)
            tok_e = time.time()

            DLLogger.log(
                (epoch, ),
                data=OrderedDict([
                    ('val_ema_loss', val_loss_e),
                    ('val_ema_mel_loss', meta_e['mel_loss'].item()),
                    ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
                    ('took', tok_e - tik_e),
                ]))
            # NOTE(review): logs the non-EMA `meta` here — probably should be
            # `meta_e`; confirm against the upstream training script.
            val_ema_tblogger.log_meta(total_iter, meta)

        # Rank 0 checkpoints every `epochs_per_checkpoint` epochs.
        if (epoch > 0 and args.epochs_per_checkpoint > 0
                and (epoch % args.epochs_per_checkpoint == 0)
                and local_rank == 0):
            checkpoint_path = os.path.join(
                args.output, f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
                            total_iter, model_config, args.amp,
                            checkpoint_path)
        if local_rank == 0:
            DLLogger.flush()

    # Finished training
    DLLogger.log((), data=OrderedDict([
        ('avg_train_loss', epoch_loss / epoch_iter),
        ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
        ('avg_train_frames/s', epoch_num_frames / epoch_time),
    ]))
    DLLogger.log((), data=OrderedDict([
        ('val_loss', val_loss),
        ('val_mel_loss', meta['mel_loss'].item()),
        ('val_frames/s', num_frames / (tok - tik)),
    ]))
    if local_rank == 0:
        DLLogger.flush()
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, LAMB optimizer and LR scheduler.

    Handles checkpoint resume (optionally rewriting the optimizer state's
    step/schedule hyperparameters for phase-2 runs), optional AMP (fp16)
    initialization, and DDP/DataParallel wrapping for CUDA or Habana devices.

    Args:
        args: Parsed training arguments (config file, checkpoint/resume flags,
            fp16/habana switches, learning-rate schedule settings, ...).
        device: Target device the model is moved to.

    Returns:
        Tuple ``(model, optimizer, lr_scheduler, checkpoint, global_step,
        criterion)``.
    """
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # Use the training variant of the fused bias+GELU activation.
    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        # Pick the newest ckpt_<step>.pt in output_dir unless a step or an
        # explicit init checkpoint was given.
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        # Phase-2 step counting restarts after the phase-1 budget.
        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    # BERT modeling uses weight sharing between word embedding and prediction decoder.
    # So make sure the storage is pointing properly even after model is moved to device.
    if args.use_habana:
        model.cls.predictions.decoder.weight = model.bert.embeddings.word_embeddings.weight

    # No weight decay on biases and normalization parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    # LAMB flavor depends on the backend: Habana fused, APEX fused, or the
    # pure-PyTorch NVLAMB fallback.
    if args.use_habana:
        if args.use_fused_lamb:
            try:
                from hb_custom import FusedLamb
            except ImportError:
                raise ImportError("Please install hbopt.")
            optimizer = FusedLamb(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)
    else:
        if torch.cuda.is_available():
            optimizer = FusedLAMB(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)

    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic",
                                              cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale,
                                              cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            # Mutate each group in place (was `for iter, item in ...`, which
            # shadowed the builtin `iter` and re-indexed the list instead of
            # using the fetched item).
            for group in checkpoint['optimizer']['param_groups']:
                group['step'] = global_step
                group['t_total'] = args.max_steps
                group['warmup'] = args.warmup_proportion
                group['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            if not args.use_jit_trace:
                if args.use_habana:
                    model = DDP(model)
                else:
                    model = DDP(model,
                                message_size=250000000,
                                gradient_predivide_factor=get_world_size())
        else:
            # Gradients are all-reduced manually after accumulation; just
            # broadcast the initial weights from rank 0.
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0, ))
    elif args.n_pu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def get_optimizer(optimizer_name: str,
                  parameters,
                  learning_rate: float,
                  weight_decay=1e-5,
                  eps=1e-5,
                  **kwargs) -> Optimizer:
    """Instantiate an optimizer by (case-insensitive) name.

    Args:
        optimizer_name: One of sgd/adam/rms/adamw/radam/ranger/lamb/diffgrad/
            novograd or the apex-fused fused_lamb/fused_sgd/fused_adam.
        parameters: Iterable of parameters (or parameter groups) to optimize.
        learning_rate: Learning rate passed to the optimizer.
        weight_decay: Weight decay coefficient.
        eps: Epsilon for Adam-family optimizers.
        **kwargs: Forwarded verbatim to the chosen optimizer constructor.

    Returns:
        A configured optimizer instance.

    Raises:
        ValueError: If the name matches no supported optimizer.
    """
    from torch.optim import SGD, Adam, RMSprop, AdamW
    from torch_optimizer import RAdam, Lamb, DiffGrad, NovoGrad, Ranger

    name = optimizer_name.lower()

    # Built-in torch.optim optimizers.
    if name == "sgd":
        return SGD(parameters,
                   learning_rate,
                   momentum=0.9,
                   nesterov=True,
                   weight_decay=weight_decay,
                   **kwargs)
    if name == "adam":
        # As Jeremy suggests
        return Adam(parameters,
                    learning_rate,
                    weight_decay=weight_decay,
                    eps=eps,
                    **kwargs)
    if name == "rms":
        return RMSprop(parameters,
                       learning_rate,
                       weight_decay=weight_decay,
                       **kwargs)
    if name == "adamw":
        return AdamW(parameters,
                     learning_rate,
                     weight_decay=weight_decay,
                     eps=eps,
                     **kwargs)
    if name == "radam":
        # As Jeremy suggests
        return RAdam(parameters,
                     learning_rate,
                     weight_decay=weight_decay,
                     eps=eps,
                     **kwargs)

    # Optimizers from torch-optimizer
    if name == "ranger":
        return Ranger(parameters,
                      learning_rate,
                      eps=eps,
                      weight_decay=weight_decay,
                      **kwargs)
    if name == "lamb":
        return Lamb(parameters,
                    learning_rate,
                    eps=eps,
                    weight_decay=weight_decay,
                    **kwargs)
    if name == "diffgrad":
        return DiffGrad(parameters,
                        learning_rate,
                        eps=eps,
                        weight_decay=weight_decay,
                        **kwargs)
    if name == "novograd":
        return NovoGrad(parameters,
                        learning_rate,
                        eps=eps,
                        weight_decay=weight_decay,
                        **kwargs)

    # Optimizers from Apex (Fused version is faster on GPU with tensor cores)
    if name == "fused_lamb":
        from apex.optimizers import FusedLAMB
        return FusedLAMB(parameters,
                         learning_rate,
                         eps=eps,
                         weight_decay=weight_decay,
                         **kwargs)
    if name == "fused_sgd":
        from apex.optimizers import FusedSGD
        return FusedSGD(parameters,
                        learning_rate,
                        momentum=0.9,
                        nesterov=True,
                        weight_decay=weight_decay,
                        **kwargs)
    if name == "fused_adam":
        from apex.optimizers import FusedAdam
        return FusedAdam(parameters,
                         learning_rate,
                         eps=eps,
                         weight_decay=weight_decay,
                         adam_w_mode=True,
                         **kwargs)

    raise ValueError("Unsupported optimizer name " + optimizer_name)
def create_optimizer(args, model, filter_bias_and_bn=True, freeze_stage=""):
    """Create an optimizer for ``model``, optionally freezing layers first.

    Args:
        args: Namespace with ``opt``, ``lr``, ``weight_decay``, ``momentum``
            and ``opt_eps``.
        model: Model to optimize.
        filter_bias_and_bn: When True and weight decay is active, split
            parameters into decay/no-decay groups via ``add_weight_decay``.
        freeze_stage: ``"stage1"``/``"stage2"`` trigger the corresponding
            layer-freezing helpers before building parameter groups.

    Returns:
        A configured optimizer (possibly Lookahead-wrapped).

    Raises:
        ValueError: If ``args.opt`` names an unknown optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        if freeze_stage == "stage1":
            stage1_train_attn(model, layer_names=['fc'])
            print('stage1, Freeze layer successfully')
        if freeze_stage == "stage2":
            stage1_train_attn(model,
                              layer_names=['layer3', 'layer4', 'se', 'fc'])
            stage2_train_layer4(model)
            print('stage2, Freeze layer successfully')
        # Apply weight decay only to the layers left unfrozen.
        # (translated from the original Chinese comment)
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()
    # NOTE(review): debug dump of every parameter's requires_grad flag;
    # consider demoting to logging.debug for production runs.
    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), \
            'APEX and CUDA required for fused optimizers'

    # 'lookahead_adam' -> wrapper prefix + base optimizer name.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps,
                          delta=0.1,
                          wd_ratio=0.01,
                          nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=weight_decay,
                         eps=args.opt_eps,
                         nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        # Was `assert False and "Invalid optimizer"` + bare `raise ValueError`
        # — the assert expression is simply False so its message never showed;
        # raise a descriptive error instead.
        raise ValueError(f'Invalid optimizer: {args.opt}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)
    return optimizer
def get_optimizer(
    model: nn.Module,
    optimizer_name: str,
    learning_rate: float,
    weight_decay: float = 1e-5,
    no_weight_decay_on_bias: bool = False,
    eps: float = 1e-5,
    **kwargs,
) -> Optimizer:
    """
    Construct an Optimizer for given model

    Args:
        model: Model to optimize. Only parameters that require_grad will be used
        optimizer_name: Name of the optimizer. Case-insensitive
        learning_rate: Target learning rate (regardless of the scheduler)
        weight_decay: Target weight decay
        no_weight_decay_on_bias: Whether to disable weight decay on bias parameters
        eps: Default epsilon for Adam-like optimizers.
        **kwargs: Additional parameters for optimizer

    Returns:

    """
    from torch.optim import ASGD, SGD, Adam, RMSprop, AdamW
    from torch_optimizer import RAdam, Lamb, DiffGrad, NovoGrad, Ranger

    # Split trainable parameters into "everything else" and bias groups so the
    # bias group can optionally be excluded from weight decay below.
    regular_params, bias_params = [], []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param_name.endswith(".bias"):
            bias_params.append(param)
        else:
            regular_params.append(param)

    if no_weight_decay_on_bias:
        parameters = regular_params
    else:
        parameters = regular_params + bias_params

    name = optimizer_name.lower()
    optimizer: Optimizer = None
    if name == "sgd":
        optimizer = SGD(
            parameters,
            lr=learning_rate,
            momentum=0.9,
            nesterov=True,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif name == "asgd":
        optimizer = ASGD(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif name == "adam":
        optimizer = Adam(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=eps,
            **kwargs,
        )
    elif name == "rms":
        optimizer = RMSprop(parameters,
                            learning_rate,
                            weight_decay=weight_decay,
                            **kwargs)
    elif name == "adamw":
        optimizer = AdamW(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=eps,
            **kwargs,
        )
    elif name == "radam":
        optimizer = RAdam(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=eps,
            **kwargs,
        )
    elif name == "ranger":
        optimizer = Ranger(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif name == "lamb":
        optimizer = Lamb(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif name == "diffgrad":
        optimizer = DiffGrad(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif name == "novograd":
        optimizer = NovoGrad(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif name == "fused_lamb":
        # Apex optimizers are imported lazily so the dependency stays optional.
        from apex.optimizers import FusedLAMB

        optimizer = FusedLAMB(parameters,
                              learning_rate,
                              eps=eps,
                              weight_decay=weight_decay,
                              **kwargs)
    elif name == "fused_sgd":
        from apex.optimizers import FusedSGD

        optimizer = FusedSGD(parameters,
                             learning_rate,
                             momentum=0.9,
                             nesterov=True,
                             weight_decay=weight_decay,
                             **kwargs)
    elif name == "fused_adam":
        from apex.optimizers import FusedAdam

        optimizer = FusedAdam(parameters,
                              learning_rate,
                              eps=eps,
                              weight_decay=weight_decay,
                              adam_w_mode=True,
                              **kwargs)
    else:
        raise KeyError(f"Cannot get optimizer by name {optimizer_name}")

    # Currently either no_wd or per-group lr
    if no_weight_decay_on_bias:
        optimizer.add_param_group({"params": bias_params, "weight_decay": 0})

    return optimizer
def prepare_optimizers(args, model, checkpoint, global_steps):
    """Build the FusedLAMB optimizer, LR scheduler(s), optional GradScaler and
    optional KFAC preconditioner for BERT pre-training.

    Args:
        args: Parsed training arguments (lr_decay, learning_rate, fp16, kfac
            settings, warmup/max steps, resume bookkeeping).
        model: The model whose named parameters are grouped and optimized.
        checkpoint: Previously loaded checkpoint dict, or None for a fresh run.
        global_steps: Current global step, written into restored optimizer state.

    Returns:
        Tuple ``(optimizer, preconditioner, lr_schedulers, scaler)`` where
        ``preconditioner``/``scaler`` may be None.
    """
    # No weight decay on biases and normalization parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.lr_decay == 'poly':
        Scheduler = PolyWarmUpScheduler
    elif args.lr_decay == 'linear':
        Scheduler = LinearWarmUpScheduler
    else:
        raise ValueError('Unknown lr decay "{}"'.format(args.lr_decay))

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)

    if checkpoint is not None:
        # When resuming past a phase boundary, rewrite the restored optimizer
        # state's step/schedule hyperparameters before loading it.
        if args.resume_step >= args.previous_phase_end_step:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_steps
            # NOTE(review): `item` is unused — the list is re-indexed by `i`.
            for i, item in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][i][
                    'step'] = global_steps
                checkpoint['optimizer']['param_groups'][i][
                    't_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][i][
                    'warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][i][
                    'lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])

    lr_schedulers = [
        Scheduler(optimizer,
                  warmup=args.warmup_proportion,
                  total_steps=args.max_steps)
    ]

    scaler = None
    if args.fp16:
        scaler = GradScaler()
        if checkpoint is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])

    preconditioner = None
    if args.kfac:
        preconditioner = kfac.KFAC(
            model,
            lr=args.learning_rate,
            factor_decay=args.kfac_stat_decay,
            damping=args.kfac_damping,
            kl_clip=args.kfac_kl_clip,
            factor_update_freq=args.kfac_factor_interval,
            inv_update_freq=args.kfac_inv_interval,
            # Skip TrainingHeads which contains the decoder, a Linear module
            # with shape (seq_len, vocab_size), such that it is too large to
            # invert
            skip_layers=args.kfac_skip_layers,
            # BERT calls KFAC very infrequently so no need to optimize for
            # communication. Optimize for memory instead.
            comm_method=kfac.CommMethod.HYBRID_OPT,
            grad_worker_fraction=0.5,
            inv_dtype=torch.float16,
            # Compute the factors and update the running averages during the
            # forward backward pass b/c we are using grad accumulation but
            # not accumulating the input/output data
            accumulate_data=False,
            compute_factor_in_hook=True,
            distribute_layer_factors=False,
            grad_scaler=scaler,
        )
        # The preconditioner gets its own warmup/decay schedule.
        lrs = Scheduler(preconditioner,
                        warmup=args.warmup_proportion,
                        total_steps=args.max_steps)
        lr_schedulers.append(lrs)
        if checkpoint is not None and 'preconditioner' in checkpoint:
            preconditioner.load_state_dict(checkpoint['preconditioner'])
        if is_main_process():
            logger.info(preconditioner)

    return optimizer, preconditioner, lr_schedulers, scaler
def create_optimizer(args, model, filter_bias_and_bn=True,
                     classification_layer_name=None):
    """Create an optimizer for ``model``, optionally with a separate learning
    rate for the classification layer.

    Args:
        args: Namespace with ``opt``, ``lr``, ``weight_decay``, ``momentum``
            and ``opt_eps``.
        model: Model to optimize.
        filter_bias_and_bn: When True and weight decay is active, split
            parameters into decay/no-decay groups.
        classification_layer_name: If given, parameter groups are built by
            ``set_lr_per_params`` so that layer can use its own LR.

    Returns:
        A configured optimizer (possibly Lookahead-wrapped).

    Raises:
        ValueError: If ``args.opt`` names an unknown optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        # batch norm and bias params
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args, model,
                                           classification_layer_name,
                                           weight_decay)
        else:
            parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.  # reset to 0; decay now lives in the param groups
    else:
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args, model,
                                           classification_layer_name,
                                           weight_decay=0)
        else:
            parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), \
            'APEX and CUDA required for fused optimizers'

    # 'lookahead_adam' -> wrapper prefix + base optimizer name.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps,
                          delta=0.1,
                          wd_ratio=0.01,
                          nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=weight_decay,
                         eps=args.opt_eps,
                         nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        # Was `assert False and "Invalid optimizer"` + bare `raise ValueError`
        # — the assert expression is simply False so its message never showed;
        # raise a descriptive error instead.
        raise ValueError(f'Invalid optimizer: {args.opt}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)
    return optimizer
def main():
    """Entry point for collaborative ALBERT pre-training.

    Parses HF-style dataclass arguments, restores the latest checkpoint if one
    exists, builds the FusedLAMB optimizer with a linear warmup schedule, and
    hands everything to ``CollaborativeTrainer``.
    """
    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments,
                               CollaborationArguments))
    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    # NOTE(review): no separator between the two f-strings — the log line runs
    # "n_gpu: <N>distributed training" together; confirm intended.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path,
                                          cache_dir=dataset_args.cache_dir)
    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path,
                                                    cache_dir=dataset_args.cache_dir)

    # find latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None,
                                key=os.path.getctime)

    if latest_checkpoint_dir is not None:
        logger.info(f'Loading model from {latest_checkpoint_dir}')
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
        logger.info(f'Training from scratch')
        model = AlbertForPreTraining(config)
        # Embedding matrix must match the tokenizer's vocabulary size.
        model.resize_token_embeddings(len(tokenizer))

    tokenized_dataset_path = Path(dataset_args.dataset_path)
    tokenized_datasets = load_from_disk(tokenized_dataset_path)

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    # No weight decay on biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = FusedLAMB(
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
    )

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=training_args.warmup_steps,
        num_training_steps=training_args.max_steps
    )

    trainer = CollaborativeTrainer(
        model=model, args=training_args, collaboration_args=collaboration_args,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        optimizers=(optimizer, lr_scheduler)
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=latest_checkpoint_dir)
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, FusedLAMB optimizer, and LR scheduler.

    Handles three initialization paths:
      1. resume from the latest ``phase{1,2}_ckpt_*.pt`` found in ``args.output_dir``;
      2. warm-start from an explicit PyTorch checkpoint (``args.init_checkpoint``);
      3. load from a TensorFlow checkpoint (``args.init_tf_checkpoint``).

    Returns:
        (model, optimizer, lr_scheduler, checkpoint, global_step) where
        ``checkpoint`` is the loaded state dict (or None) and ``global_step``
        is the step the run resumes from (0 for a fresh start).
    """
    global_step = 0
    args.resume_step = 0
    checkpoint = None

    # Mirror the fusion/unpadding command-line switches onto the model config.
    config = BertConfig.from_json_file(args.bert_config_path)
    config.fused_mha = args.fused_mha
    config.fused_gelu_bias = args.fused_gelu_bias
    config.dense_seq_output = args.dense_seq_output
    config.unpad = args.unpad
    config.pad = args.pad
    config.fuse_qkv = not args.disable_fuse_qkv
    config.fuse_scale = not args.disable_fuse_scale
    config.fuse_mask = not args.disable_fuse_mask
    config.fuse_dropout = args.enable_fuse_dropout
    config.apex_softmax = not args.disable_apex_softmax
    config.enable_stream = args.enable_stream
    # Feature dependencies: fused mask implies apex softmax; no padding
    # implies streamed execution; unpadded inputs disable fused MHA.
    if config.fuse_mask == True: config.apex_softmax = True
    if config.pad == False: config.enable_stream = True
    if config.unpad == True: config.fused_mha = False

    # Padding for divisibility by 8 (keeps embedding dims Tensor-Core friendly).
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # Load from Pyt checkpoint - either given as init_checkpoint, or picked up from output_dir if found
    if args.init_checkpoint is not None or found_resume_checkpoint(args):
        # Prepare model
        model = BertForPreTraining(config)
        if args.init_checkpoint is None:  # finding checkpoint in output_dir
            checkpoint_str = "phase2_ckpt_*.pt" if args.phase2 else "phase1_ckpt_*.pt"
            model_names = [f for f in glob.glob(os.path.join(args.output_dir, checkpoint_str))]
            # Checkpoint files are named "<prefix>_<step>.pt"; the largest
            # step number is the most recent checkpoint.
            global_step = max([int(x.split('.pt')[0].split('_')[-1].strip()) for x in model_names])
            args.resume_step = global_step  # used for throughput computation
            resume_init_checkpoint = os.path.join(args.output_dir, checkpoint_str.replace("*", str(global_step)))
            print("Setting init checkpoint to %s - which is the latest in %s" %(resume_init_checkpoint, args.output_dir))
            checkpoint=torch.load(resume_init_checkpoint, map_location="cpu")
        else:
            # Explicit init checkpoint stores the model state under "model".
            checkpoint=torch.load(args.init_checkpoint, map_location="cpu")["model"]

        # Fused MHA requires a remapping of checkpoint parameters
        if config.fused_mha:
            checkpoint_remapped = remap_attn_parameters(checkpoint)
            # strict=False: remapped dict may omit params the fused kernels replace.
            model.load_state_dict(checkpoint_remapped, strict=False)
        else:
            model.load_state_dict(checkpoint, strict=True)
    else:  # Load from TF Checkpoint
        model = BertForPreTraining.from_pretrained(args.init_tf_checkpoint, from_tf=True, config=config)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    # LAMB weight decay is skipped for bias and normalization parameters.
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay_rate},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    # MLPerf compliance logging of the optimizer hyper-parameters.
    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR,
                            value=args.learning_rate, sync=False)
    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          betas=(args.opt_lamb_beta_1, args.opt_lamb_beta_2))
    mlperf_logger.log_event(key='opt_epsilon', value=optimizer.defaults['eps'],
                            sync=False)
    b1, b2 = optimizer.defaults['betas']
    mlperf_logger.log_event(key='opt_lamb_beta_1', value=b1, sync=False)
    mlperf_logger.log_event(key='opt_lamb_beta_2', value=b2, sync=False)
    mlperf_logger.log_event(key='opt_lamb_weight_decay_rate',
                            value=optimizer.defaults['weight_decay'],
                            sync=False)

    # warmup_steps == 0 means "derive warmup from warmup_proportion".
    if args.warmup_steps == 0:
        warmup_steps = int(args.max_steps * args.warmup_proportion)
        warmup_start = 0
    else:
        warmup_steps = args.warmup_steps
        warmup_start = args.start_warmup_step
    # degree=1.0 makes the poly decay linear down to LR 0 at max_steps.
    lr_scheduler = LinearWarmupPolyDecayScheduler(optimizer,
                                                  start_warmup_steps=warmup_start,
                                                  warmup_steps=warmup_steps,
                                                  total_steps=args.max_steps,
                                                  end_learning_rate=0.0,
                                                  degree=1.0)

    if args.fp16:
        # loss_scale == 0 selects apex dynamic loss scaling; otherwise static.
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale)
        # Seed the dynamic scaler with a large initial scale (overridable via env).
        amp._amp_state.loss_scalers[0]._loss_scale = float(os.getenv("INIT_LOSS_SCALE", 2**20))

    if found_resume_checkpoint(args):
        optimizer.load_state_dict(checkpoint['optimizer'])  # restores m,v states (only if resuming checkpoint, not for init_checkpoint and init_tf_checkpoint for now)

        # Restore AMP master parameters
        if args.fp16:
            # Force apex to materialize FP32 master weights now, then reload the
            # optimizer state so it attaches to those master params, and finally
            # copy the saved master weights in. The second load_state_dict call
            # is intentional — the first ran before master weights existed.
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            # apex DDP with gradient pre-division to keep all-reduce sums stable.
            model = DDP(model,
                        message_size=250000000,
                        gradient_predivide_factor=torch.distributed.get_world_size())
        else:
            # Gradients are all-reduced manually after accumulation elsewhere;
            # here we only broadcast rank-0 weights so all ranks start identical.
            flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )

    return model, optimizer, lr_scheduler, checkpoint, global_step
def main():
    """FastPitch training entry point.

    Parses CLI arguments, builds the model/optimizer/dataloaders, optionally
    resumes from a checkpoint, then runs the epoch loop with gradient
    accumulation, optional AMP, optional EMA weights, and per-epoch validation.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    # parse_known_args: model-specific options are parsed in a second pass below.
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        # ARPAbet phoneme inputs require the CMU pronunciation dictionary.
        cmudict.initialize(args.cmudict_path, keep_ambiguous=True)

    distributed_run = args.world_size > 1

    # Offset the seed by rank so each process draws different random numbers.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    # Only rank 0 touches the filesystem / writes logs.
    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    # Second parsing pass: add FastPitch-specific options, reject leftovers.
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)
    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError

    # GradScaler is a no-op when AMP is disabled, so it can be used unconditionally.
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # Keep an exponential-moving-average copy of the weights for evaluation.
    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    # Single-element lists so load_checkpoint can write back by reference.
    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args, model, ema_model, optimizer, scaler,
                        start_epoch, start_iter, model_config, ch_fpath)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        # Scratch directory for pitch values computed on the fly.
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))

    if distributed_run:
        # DistributedSampler shuffles internally; DataLoader shuffle must be off.
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
    train_loader = DataLoader(trainset, num_workers=4, shuffle=shuffle,
                              sampler=train_sampler, batch_size=args.batch_size,
                              pin_memory=True, persistent_workers=True,
                              drop_last=True, collate_fn=collate_fn)

    if args.ema_decay:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()

    bmark_stats = BenchmarkStats()

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            # Re-seed the sampler so each epoch sees a different shard ordering.
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = time.perf_counter()

        epoch_iter = 0
        # One optimizer step consumes grad_accumulation micro-batches.
        num_iters = len(train_loader) // args.grad_accumulation
        for batch in train_loader:

            if accumulated_steps == 0:
                # Start of a new effective (accumulated) step.
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                # Optional attention-binarization KL term, ramped in linearly
                # over kl_loss_warmup_epochs starting at kl_loss_start_epoch.
                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    # y_pred is a 12-tuple; positions 8/9 are soft/hard attention.
                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    kl_weight = min(
                        (epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs,
                        1.0) * args.kl_loss_weight
                    # Detach for logging so the metric doesn't hold the graph.
                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
                    loss += kl_weight * binarization_loss
                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0
                    binarization_loss = 0

                # Scale loss so accumulated gradients average over micro-batches.
                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size)
                        for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.grad_accumulation == 0:
                # Full effective batch accumulated: clip, step, log.
                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    # Unscale before clipping so the threshold applies to true grads.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    optimizer.step()

                if args.ema_decay > 0.0:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                iter_time = time.perf_counter() - iter_start_time
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                log(
                    (epoch, epoch_iter, num_iters),
                    tb_total_steps=total_iter,
                    subset='train',
                    data=OrderedDict([
                        ('loss', iter_loss),
                        ('mel_loss', iter_mel_loss),
                        ('kl_loss', iter_kl_loss),
                        ('kl_weight', kl_weight),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])
                    ]),
                )

                # Reset accumulators for the next effective step.
                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

        # Finished epoch
        epoch_loss /= epoch_iter
        epoch_mel_loss /= epoch_iter
        epoch_time = time.perf_counter() - epoch_start_time

        log(
            (epoch, ),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss),
                              ('mel_loss', epoch_mel_loss),
                              ('frames/s', epoch_num_frames / epoch_time),
                              ('took', epoch_time)]),
        )
        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
                           epoch_time)

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu)

        if args.ema_decay > 0:
            # Validate the EMA weights separately (logged under val_ema).
            validate(ema_model, epoch, total_iter, criterion, valset,
                     args.batch_size, collate_fn, distributed_run, batch_to_gpu,
                     ema=True)

        maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch,
                              total_iter, model_config)
        logger.flush()

    # Finished training
    if len(bmark_stats) > 0:
        log((), tb_total_steps=None, subset='train_avg',
            data=bmark_stats.get(args.benchmark_epochs_num))

    # Final validation pass with epoch=None.
    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu)
def run(config, args):
    """WaveGrad training loop.

    Resolves rank/world-size (env vars win over CLI args), builds the model,
    optimizer, and dataloaders, optionally resumes from a checkpoint, then
    trains for ``args.epochs`` epochs with optional apex AMP (O1) and DDP,
    periodically switching the diffusion noise schedule for testing and
    saving checkpoints.
    """
    # Prefer torchrun-style env vars; fall back to explicit CLI arguments.
    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    # Per-rank seeding so ranks draw different randomness.
    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    # if local_rank == 0:
    #     if not os.path.exists(args.output):
    #         os.makedirs(args.output)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    if distributed_run:
        init_distributed(args, world_size, local_rank)
    # NOTE(review): `device` is computed but the model below is moved with
    # .cuda() unconditionally — the cpu path appears unsupported; confirm.
    device = torch.device('cuda' if args.cuda else 'cpu')

    if local_rank == 0:
        print("start training")
        print("args", args)
        print("config", config)

    #############################################
    # model
    if local_rank == 0:
        print("load model")
    model = WaveGrad(config).cuda()

    # optimizer amp config
    if local_rank == 0:
        print("configure optimizer and amp")
    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    elif args.optimizer == 'pytorch':
        optimizer = torch.optim.Adam(model.parameters(), **kw)
    else:
        raise ValueError

    if args.amp:
        # apex O1 mixed precision; must wrap before DDP.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if distributed_run:
        # NOTE(review): uses args.local_rank here, while the rank above may
        # come from the LOCAL_RANK env var — these can disagree under
        # torchrun; probably should be the local `local_rank`. Verify.
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    # Single-element lists so load_checkpoint can write back by reference.
    start_epoch = [1]
    start_iter = [0]

    ################
    # load checkpoint
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
        load_checkpoint(local_rank, model, optimizer, start_epoch, start_iter,
                        config, args.amp, ch_fpath, world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    # dataloader ##########################################################
    if local_rank == 0:
        print("load dataset")
    if local_rank == 0:
        print("prepare train dataset")
    train_dataset = AudioDataset(config, training=True)
    # distributed sampler
    if distributed_run:
        # DistributedSampler shuffles internally; DataLoader shuffle must be off.
        train_sampler, shuffle = DistributedSampler(train_dataset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(train_dataset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True)

    # ground truth samples
    if local_rank == 0:
        print("prepare test_dataset")
    test_dataset = AudioDataset(config, training=False)
    test_loader = DataLoader(test_dataset, batch_size=1)
    test_batch = test_dataset.sample_test_batch(
        config.training_config.n_samples_to_test)

    # Log ground truth test batch
    if local_rank == 0:
        print("save truth wave and mel")
    # Fixed mel-spectrogram extractor used both as model input and for logging.
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    audios = {
        f'audio_{index}/gt': audio
        for index, audio in enumerate(test_batch)
    }
    specs = {
        f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
        for index, audio in enumerate(test_batch)
    }

    ####### loop start
    # epoch
    iteration = 0
    model.train()
    val_loss = 0.0

    torch.cuda.synchronize()
    if local_rank == 0:
        print("epoch start")
    for epoch in range(start_epoch, args.epochs + 1):
        tic_epoch = time.time()
        epoch_loss = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        epoch_iter = 0
        #iteration = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps

        # Restore the (longer) training noise schedule each epoch, since the
        # test step below may have swapped in the short test schedule.
        model.module.set_new_noise_schedule(  # 1000 default
            init=torch.linspace,
            init_kwargs={
                'steps': config.training_config.training_noise_schedule.n_iter,
                'start': config.training_config.training_noise_schedule.betas_range[0],
                'end': config.training_config.training_noise_schedule.betas_range[1]
            }
        )

        for i, batch in enumerate(train_loader):
            tic_iter = time.time()

            # Capture LR before/after warmup adjustment for the progress print.
            old_lr = optimizer.param_groups[0]['lr']
            adjust_learning_rate(iteration, optimizer, args.learning_rate,
                                 args.warmup_steps)
            new_lr = optimizer.param_groups[0]['lr']

            model.zero_grad()
            batch = batch.cuda()
            mels = mel_fn(batch)

            # Training step
            # NOTE(review): second zero_grad() is redundant (duplicate of the
            # call a few lines up); harmless but worth cleaning up.
            model.zero_grad()
            loss = model.module.compute_loss(mels, batch)
            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
            else:
                reduced_loss = loss.item()
            # if np.isnan(reduced_loss):
            #     raise Exception("loss is NaN")
            iter_loss += reduced_loss

            # Under apex AMP, clip the FP32 master grads, not the model grads.
            if args.amp:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            toc_iter = time.time()
            dur_iter = toc_iter - tic_iter
            epoch_loss += iter_loss
            iter_size = len(train_loader)
            dur_epoch_est = iter_size * dur_iter
            if local_rank == 0:
                print(
                    "\nepoch {:4d} | iter {:>12d} {:>3d}/{:3d} | {:3.2f}s/iter est {:4.2f}s/epoch | losses {:>12.6f} {:>12.6f} LR {:e}--> {:e}"
                    .format(epoch, iteration, i, iter_size, dur_iter,
                            dur_epoch_est, iter_loss, grad_norm, old_lr,
                            new_lr),
                    end='')
            iter_loss = 0
            iteration += 1

        # Finished epoch
        toc_epoch = time.time()
        dur_epoch = toc_epoch - tic_epoch
        if local_rank == 0:
            print("for {}item, {:4.2f}s/epoch ".format(
                iter_size, dur_epoch))

        # Test step
        if epoch % config.training_config.test_interval == 0:
            # Swap in the short test-time noise schedule for inference.
            model.module.set_new_noise_schedule(  # 50 for default
                init=torch.linspace,
                init_kwargs={
                    'steps': config.training_config.test_noise_schedule.n_iter,
                    'start': config.training_config.test_noise_schedule.betas_range[0],
                    'end': config.training_config.test_noise_schedule.betas_range[1]
                }
            )

        if (epoch % args.epochs_per_checkpoint == 0):
            ch_path = os.path.join(args.output,
                                   "WaveGrad_ch_{:d}.pt".format(epoch))
            save_checkpoint(local_rank, model, optimizer, epoch, iteration,
                            config, args.amp, ch_path)