def init_training(self):
    model = self.elements["model"]
    start_epoch = self.params["start_epoch"]
    exist_model = self.params["exist_model"]
    model_dir = self.params["model_dir"]
    model_blueprint = self.params["model_blueprint"]
    suffix = self.params["suffix"]

    if start_epoch <= 0 and utils.is_main_training():
        model_creation = model.get_model_creation()
        utils.write_nnet_config(model_blueprint, model_creation,
                                "{0}/config/nnet.config".format(model_dir))

    ## Recover checkpoint | Transform learning | Initialize parameters
    if start_epoch > 0:
        # This train stage is equal to the number of completed epochs.
        if utils.is_main_training():
            logger.info("Recover training from {0} epoch.".format(start_epoch))
        model.load_state_dict(torch.load('{0}/{1}.{2}'.format(model_dir, start_epoch, suffix),
                                         map_location="cpu"))
    elif os.path.exists(exist_model):
        if utils.is_main_training():
            logger.info("Use {0} as the initial model to start transform-training.".format(exist_model))
        model.load_transform_state_dict(torch.load(exist_model, map_location="cpu"))
    else:
        # Just use the raw initial model, or initialize it again with some init functions here.
        pass  # It now means the raw initial model is used.

    if utils.use_horovod():
        import horovod.torch as hvd

        # Broadcast parameters from rank 0 to all other processes.
        hvd.broadcast_parameters(self.elements["model"].state_dict(), root_rank=0)

        # For optimizer wrappers such as lookahead.
        if getattr(self.elements["optimizer"], "optimizer", None) is not None:
            raise TypeError("Do not support using lookahead with horovod now.")
        else:
            # Broadcast optimizer state.
            hvd.broadcast_optimizer_state(self.elements["optimizer"], root_rank=0)
            self.elements["optimizer"] = hvd.DistributedOptimizer(
                self.elements["optimizer"],
                named_parameters=self.elements["model"].named_parameters())

    ## Select device
    model = self.select_device()

    # The original model is built in libs.nnet.framework.TopVirtualNnet, and it is not available after
    # being wrapped by DistributedDataParallel. So, to call functions of TopVirtualNnet conveniently,
    # self.elements["model_forward"] is set here to name the DistributedDataParallel wrapper.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        self.elements["model"] = model.module
        self.elements["model_forward"] = model
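The module/wrapper split at the end of init_training can be reproduced in isolation. Below is a minimal runnable sketch of that pattern (a toy model and a single-process gloo group, purely illustrative and not code from this repository): the DDP wrapper handles the forward pass, while .module keeps the original object and its custom methods reachable.

import os
import torch
import torch.distributed as dist

# A single-process "world" only to make the example runnable on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

raw_model = torch.nn.Linear(4, 2)
ddp_model = torch.nn.parallel.DistributedDataParallel(raw_model)

out = ddp_model(torch.randn(3, 4))    # Forward (and gradient sync) goes through the wrapper.
assert ddp_model.module is raw_model  # Custom attributes/methods stay reachable via .module.

dist.destroy_process_group()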
def run_lr_finder(self, save_file: str, comment=None, init_lr=1e-8, final_lr=10., num_iters=None, beta=0.98):
    self.init_training()
    log_dir = self.params["model_dir"] + "/log/"  # For tensorboardX.
    if comment is not None:
        save_file = comment + "-" + save_file
    save_file = log_dir + save_file

    log_lrs, values_matrix = self.lr_finder_compute(self.elements["data"].train_loader,
                                                    self.elements["optimizer"],
                                                    init_lr=init_lr, final_lr=final_lr,
                                                    num_iters=num_iters, beta=beta,
                                                    log_dir=log_dir, comment=comment)

    if utils.is_main_training():
        df = pd.DataFrame(np.vstack([log_lrs, values_matrix]).T,
                          columns=["log_lr", "train_loss", "train_acc", "valid_loss", "valid_acc"])
        logger.info("Save lr finder values to {}.".format(save_file))
        df.to_csv(save_file)
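A hypothetical call sketch (the trainer object and the argument values are illustrative assumptions, not taken from the launcher): sweep the learning rate over 500 iterations and write the CSV under model_dir/log/.

trainer.run_lr_finder("lr_finder.csv", comment="exp1", init_lr=1e-6, final_lr=1.0, num_iters=500)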
def step(self, training_point=None, valid_metric=None):
    if self.name == "warmR":
        if self.lr_decay_step > 0 and training_point[1] % self.lr_decay_step == 0:
            self.lr_scheduler.step(training_point[0] + training_point[1] / training_point[2])
        elif self.lr_decay_step == 0:
            self.lr_scheduler.step(training_point[0])
    elif self.name == "1cycle":
        self.lr_scheduler.step()
    elif self.name == "reduceP":
        # Sample a point at which the valid metrics are computed and adjust the learning rate there.
        if self.is_reduce_point(training_point):
            # Horovod is not supported here for now.
            if utils.use_ddp():
                # Multi-GPU case.
                # Do not compute the valid set in every process; compute it in the main process only
                # and broadcast the metrics to the other processes.
                if not self.init:
                    device = utils.get_device_from_optimizer(self.lr_scheduler.optimizer)
                    # A buffer tensor prepared for torch.distributed.broadcast. Keep it separate from
                    # self.metric, which stores the name of the selected metric ("valid_loss"/"valid_acc").
                    self.metric_buffer = torch.randn(2, device=device)
                    # Create a new group to broadcast this metric tensor. It is important.
                    self.group = torch.distributed.new_group(
                        ranks=list(range(torch.distributed.get_world_size())), backend="nccl")
                    self.init = True

                if utils.is_main_training():
                    # Gather the new values of the metric.
                    self.metric_buffer = torch.tensor([valid_metric[0], valid_metric[1]],
                                                      device=self.metric_buffer.device)

                # Broadcast.
                torch.distributed.broadcast(self.metric_buffer, 0, group=self.group)
                metric = self.metric_buffer[0] if self.metric == "valid_loss" else self.metric_buffer[1]
            else:
                # Single-GPU case.
                metric = valid_metric[0] if self.metric == "valid_loss" else valid_metric[1]

            self.lr_scheduler.step(metric)
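The broadcast logic above can be reduced to a small helper. A minimal sketch, assuming an already initialized torch.distributed process group; the helper name broadcast_valid_metric is illustrative and not part of the codebase.

import torch
import torch.distributed as dist

def broadcast_valid_metric(valid_loss, valid_acc, device):
    # Rank 0 fills the buffer with the metrics it computed; every rank leaves with the same values.
    buffer = torch.zeros(2, device=device)
    if dist.get_rank() == 0:
        buffer = torch.tensor([float(valid_loss), float(valid_acc)], device=device)
    dist.broadcast(buffer, src=0)
    return buffer[0].item(), buffer[1].item()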
def get_bunch_from_csv(self, trainset_csv: str, valid_csv: str = None, egs_params: dict = {},
                       data_loader_params_dict: dict = {}):
    Egs = ChunkEgs
    if "egs_type" in egs_params.keys():
        egs_type = egs_params.pop("egs_type")
        if egs_type == "chunk":
            pass
        elif egs_type == "vector":
            Egs = VectorEgs
        else:
            raise TypeError("Do not support {} egs now. Select one from [chunk, vector].".format(egs_type))

    trainset = Egs(trainset_csv, **egs_params)

    # For multi-GPU training.
    if not utils.is_main_training():
        valid = None

    if valid_csv != "" and valid_csv is not None:
        valid = Egs(valid_csv)
    else:
        valid = None

    return self(trainset, valid, **data_loader_params_dict)
def lr_finder_compute(self, train_batch):
    model = self.elements["model"]
    if model.use_step:
        model.step(*self.training_point)

    loss, acc = self.train_one_batch(train_batch)
    model.backward_step(*self.training_point)

    # Placeholders for non-main processes; the lr-finder wrapper discards their values.
    valid_loss, valid_acc = 0., 0.
    if utils.is_main_training():
        valid_loss, valid_acc = self.compute_validation(self.elements["data"].valid_loader)

    return ["train_loss", "train_acc", "valid_loss", "valid_acc"], [loss, acc, valid_loss, valid_acc]
def lr_finder_compute(self, train_batch):
    model = self.elements["model"]
    if model.use_step:
        model.step(*self.training_point)

    loss, acc = self.train_one_batch(train_batch)
    model.backward_step(*self.training_point)

    # Placeholders for non-main processes; the lr-finder wrapper discards their values.
    valid_loss, valid_acc, orth = 0., 0., 0.
    if utils.is_main_training():
        valid_loss, valid_acc = self.compute_validation(self.elements["data"].valid_loader)

        # Mean pairwise dot product between L2-normalized class weight vectors of the loss layer.
        weight = model.loss.weight.squeeze(dim=2)
        weight = F.normalize(weight, dim=1)
        orth = 0.
        for i in range(weight.shape[0]):
            for j in range(i + 1, weight.shape[0]):
                orth += torch.dot(weight[i], weight[j]).item()
        orth /= weight.shape[0] * (weight.shape[0] - 1) / 2

    return ["train_loss", "train_acc", "valid_loss", "valid_acc", "orth"], [loss, acc, valid_loss, valid_acc, orth]
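The double loop above averages the dot products of all distinct pairs of normalized class weights; a value near 0 means the class centers are close to mutually orthogonal. A minimal standalone sketch of the same quantity, vectorized with a Gram matrix (plain PyTorch, random example input):

import torch
import torch.nn.functional as F

def mean_pairwise_similarity(weight: torch.Tensor) -> float:
    # weight: a [num_classes, feat_dim] matrix (hypothetical example input).
    weight = F.normalize(weight, dim=1)
    gram = weight @ weight.t()                 # All pairwise dot products.
    num_classes = weight.shape[0]
    pairs = torch.triu(gram, diagonal=1)       # Keep only the strict upper triangle (i < j).
    num_pairs = num_classes * (num_classes - 1) / 2
    return (pairs.sum() / num_pairs).item()

# Example: 10 random class vectors of dimension 256.
print(mean_pairwise_similarity(torch.randn(10, 256)))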
##--------------------------------------------------##
## ######################################################### START #########################################################
##
#### Set seed
utils.set_all_seed(1024)

##
#### Init environment
# It is used for multi-GPU training if used (number of gpu-ids > 1).
# It does nothing for single-GPU training.
utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution, args.port)

##
#### Set sleep time for a rest
# Use it to run a launcher with a countdown when there is no free GPU memory,
# but you really want to go to bed and know when the GPU memory will be free.
if args.sleep > 0 and utils.is_main_training():
    logger.info("This launcher will sleep {}s before starting...".format(args.sleep))
    time.sleep(args.sleep)

##
#### Auto-config params
# If multiple GPUs are used, auto-scale the learning rate by multiplying it by the number of processes.
optimizer_params["learn_rate"] = utils.auto_scale_lr(optimizer_params["learn_rate"])
# It is used for model.step() defined in the model blueprint.
if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]:
    model_params["step_params"]["T"] = (lr_scheduler_params["warmR.T_max"],
                                        lr_scheduler_params["warmR.T_mult"])

##
#### Preprocess
if stage <= 2 and endstage >= 0 and utils.is_main_training():
if args.sleep > 0:
    time.sleep(args.sleep)

##
#### Init environment
# It is used for multi-GPU training if used (number of gpu-ids > 1).
# It does nothing for single-GPU training.
utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution, args.port)

##
#### Auto-config params
# If multiple GPUs are used, auto-scale the learning rate by multiplying it by the number of processes.
optimizer_params["learn_rate"] = utils.auto_scale_lr(optimizer_params["learn_rate"])
# It is used for model.step() defined in the model blueprint.
if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]:
    model_params["step_params"]["T"] = (lr_scheduler_params["warmR.T_max"],
                                        lr_scheduler_params["warmR.T_mult"])

##
#### Preprocess
if stage <= 2 and endstage >= 0 and utils.is_main_training():
    # Only limited options are exposed here because it is not convenient.
    # It is suggested to pre-execute this shell script to keep full flexibility, and then continue with this launcher.
    kaldi_common.execute_command("sh subtools/pytorch/pipeline/preprocess_to_egs.sh "
                                 "--stage {stage} --endstage {endstage} --valid-split-type {valid_split_type} "
                                 "--nj {nj} --cmn {cmn} --limit-utts {limit_utts} --min-chunk {chunk_size} --overlap {overlap} "
                                 "--sample-type {sample_type} --chunk-num {chunk_num} --scale {scale} --force-clear {force_clear} "
                                 "--valid-num-utts {valid_utts} --valid-chunk-num {valid_chunk_num_every_utt} "
                                 "{traindata} {egs_dir}".format(stage=stage, endstage=endstage,
                                                                valid_split_type=valid_split_type,
                                                                nj=preprocess_nj, cmn=str(cmn).lower(),
                                                                limit_utts=limit_utts, chunk_size=chunk_size,
                                                                overlap=overlap, sample_type=sample_type,
                                                                chunk_num=chunk_num, scale=scale,
                                                                force_clear=str(force_clear).lower(),
                                                                valid_utts=valid_utts,
                                                                valid_chunk_num_every_utt=valid_chunk_num_every_utt,
                                                                traindata=traindata, egs_dir=egs_dir))

#### Train model
if stage <= 3 <= endstage:
def run(self):
    """Main function to start a training process.
    """
    try:
        self.init_training()

        if utils.is_main_training():
            self.reporter = Reporter(self)

        start_epoch = self.params["start_epoch"]
        epochs = self.params["epochs"]
        data = self.elements["data"]
        model = self.elements["model"]
        model_forward = self.elements["model_forward"]  # See init_training().
        lr_scheduler = self.elements["lr_scheduler"]
        base_optimizer = self.elements["optimizer"]

        # For lookahead.
        if getattr(base_optimizer, "optimizer", None) is not None:
            base_optimizer = base_optimizer.optimizer
        last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

        if utils.is_main_training():
            logger.info("Training will run for {0} epochs.".format(epochs))

        for this_epoch in range(start_epoch, epochs):
            # Set the random seed w.r.t. the epoch for distributed training.
            if isinstance(data.train_loader.sampler, torch.utils.data.distributed.DistributedSampler) and \
               self.params["ddp_random_epoch"]:
                data.train_loader.sampler.set_epoch(this_epoch)

            for this_iter, batch in enumerate(data.train_loader, 0):
                self.training_point = (this_epoch, this_iter, data.num_batch_train)  # It is important for the reporter.

                if model.use_step:
                    model.step(*self.training_point)

                loss, acc = self.train_one_batch(batch)
                model.backward_step(*self.training_point)

                # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler,
                # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                # and some simple schedulers whose step() parameter is 'epoch' only are supported.
                lr_scheduler_params = {"training_point": self.training_point}

                if utils.is_main_training() or lr_scheduler.name == "reduceP":
                    if data.valid_loader and (self.reporter.is_report(self.training_point) or
                                              lr_scheduler.is_reduce_point(self.training_point)):
                        valid_loss, valid_acc = self.compute_validation(data.valid_loader)

                        # real_snapshot is set for tensorboard to avoid a workspace problem.
                        real_snapshot = {"train_loss": loss, "valid_loss": valid_loss,
                                         "train_acc": acc * 100, "valid_acc": valid_acc * 100}
                        snapshot = {"train_loss": "{0:.6f}".format(loss),
                                    "valid_loss": "{0:.6f}".format(valid_loss),
                                    "train_acc": "{0:.2f}".format(acc * 100),
                                    "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                    "real": real_snapshot}

                        # Optionally track the mean pairwise similarity of the loss weights.
                        try:
                            weight = model.loss.weight.squeeze(dim=2)
                            weight = F.normalize(weight, dim=1)
                            orth_snapshot = {"orth_snp": 0.}
                            for i in range(weight.shape[0]):
                                for j in range(i + 1, weight.shape[0]):
                                    orth_snapshot["orth_snp"] += torch.dot(weight[i], weight[j]).item()
                            orth_snapshot["orth_snp"] /= weight.shape[0] * (weight.shape[0] - 1) / 2
                            real_snapshot.update(orth_snapshot)
                            snapshot.update(orth_snapshot)
                            snapshot["real"] = real_snapshot
                        except Exception:
                            pass

                        # For ReduceLROnPlateau.
                        lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc)
                    else:
                        real_snapshot = {"train_loss": loss, "train_acc": acc * 100}
                        snapshot = {"train_loss": "{0:.6f}".format(loss), "valid_loss": "",
                                    "train_acc": "{0:.2f}".format(acc * 100), "valid_acc": "",
                                    "real": real_snapshot}

                if lr_scheduler is not None:
                    # It is not convenient to wrap lr_scheduler (work in progress).
                    if isinstance(lr_scheduler, LRSchedulerWrapper):
                        lr_scheduler.step(**lr_scheduler_params)
                        if lr_scheduler.name == "reduceP" and utils.is_main_training():
                            current_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
                            if current_lr < last_lr:
                                last_lr = current_lr
                                self.save_model(from_epoch=False)
                    else:
                        # For some pytorch lr_schedulers, but it is not available for all.
                        lr_scheduler.step(this_epoch)

                if utils.is_main_training():
                    self.reporter.update(snapshot)

            if utils.is_main_training():
                self.save_model()

        if utils.is_main_training():
            self.reporter.finish()
    except BaseException as e:
        if utils.use_ddp():
            utils.cleanup_ddp()
        if not isinstance(e, KeyboardInterrupt):
            traceback.print_exc()
        sys.exit(1)
def run(self):
    """Main function to start a training process.
    """
    try:
        self.init_training()

        if utils.is_main_training():
            self.reporter = Reporter(self)

        start_epoch = self.params["start_epoch"]
        epochs = self.params["epochs"]
        data = self.elements["data"]
        model = self.elements["model"]
        model_forward = self.elements["model_forward"]  # See init_training().
        lr_scheduler = self.elements["lr_scheduler"]

        if utils.is_main_training():
            logger.info("Training will run for {0} epochs.".format(epochs))

        for this_epoch in range(start_epoch, epochs):
            for this_iter, batch in enumerate(data.train_loader, 0):
                self.training_point = (this_epoch, this_iter, data.num_batch_train)  # It is important for the reporter.

                if model.use_step:
                    model.step(*self.training_point)

                loss, acc = self.train_one_batch(batch)

                # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler,
                # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                # and some simple schedulers whose step() parameter is 'epoch' only are supported.
                lr_scheduler_params = {"training_point": self.training_point}

                if utils.is_main_training() or lr_scheduler.name == "reduceP":
                    if data.valid_loader and (self.reporter.is_report(self.training_point) or
                                              lr_scheduler.is_reduce_point(self.training_point)):
                        valid_loss, valid_acc = self.compute_validation(data.valid_loader)
                        snapshot = {"train_loss": "{0:.6f}".format(loss),
                                    "valid_loss": "{0:.6f}".format(valid_loss),
                                    "train_acc": "{0:.2f}".format(acc * 100),
                                    "valid_acc": "{0:.2f}".format(valid_acc * 100)}
                        # For ReduceLROnPlateau.
                        lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc)
                    else:
                        snapshot = {"train_loss": "{0:.6f}".format(loss), "valid_loss": "",
                                    "train_acc": "{0:.2f}".format(acc * 100), "valid_acc": ""}

                if lr_scheduler is not None:
                    # It is not convenient to wrap lr_scheduler (work in progress).
                    if isinstance(lr_scheduler, LRSchedulerWrapper):
                        lr_scheduler.step(**lr_scheduler_params)
                    else:
                        # For some pytorch lr_schedulers, but it is not available for all.
                        lr_scheduler.step(this_epoch)

                if utils.is_main_training():
                    self.reporter.update(snapshot)

            if utils.is_main_training():
                self.save_model()

        if utils.is_main_training():
            self.reporter.finish()
    except BaseException as e:
        if utils.use_ddp():
            utils.cleanup_ddp()
        if not isinstance(e, KeyboardInterrupt):
            traceback.print_exc()
        sys.exit(1)
def wrapper(self, trn_loader, optimizer, init_lr=1e-6, final_lr=10., num_iters=None, beta=0.98,
            split=[5, -10], log_dir=None, comment=None):
    if init_lr <= 0:
        raise ValueError("Expected init_lr > 0, but got init_lr = {}.".format(init_lr))
    if final_lr < init_lr:
        raise ValueError("Expected final_lr > init_lr, but got final_lr = {} < init_lr = {}.".format(final_lr, init_lr))
    if num_iters is not None and num_iters <= 1:
        raise ValueError("Expected num_iters > 1, but got {}.".format(num_iters))
    if not isinstance(trn_loader, DataLoader):
        raise TypeError("Expected DataLoader, but got {}.".format(type(trn_loader).__name__))
    if not isinstance(optimizer, Optimizer):
        raise TypeError("Expected Optimizer, but got {}.".format(type(optimizer).__name__))

    # If num_iters is None, just run one epoch.
    if num_iters is not None:
        epochs = (num_iters - 1) // len(trn_loader) + 1
    else:
        num_iters = len(trn_loader)
        epochs = 1

    logger.info("Run lr finder from init_lr = {} to final_lr = {} with {} iters.".format(init_lr, final_lr, num_iters))

    # Init.
    mult = (final_lr / init_lr) ** (1 / (num_iters - 1))
    num_batch = 0
    avg_values = 0.
    log_lrs = []

    if utils.is_main_training():
        reporter = LRFinderReporter(num_iters, log_dir=log_dir, comment=comment)

    # Start.
    lr = init_lr
    optimizer.param_groups[0]['lr'] = lr

    for this_epoch in range(epochs):
        for batch in trn_loader:
            num_batch += 1

            # The wrapped function returns a list of keys and a list of float values.
            keys, values = function(self, batch)
            values = np.array(values)

            if not utils.is_main_training():
                continue

            # Compute the smoothed values. avg_values becomes a numpy vector rather than 0.
            avg_values = beta * avg_values + (1 - beta) * values
            smoothed_values = avg_values / (1 - beta ** num_batch)

            snapshot = {"lr": lr}
            for i in range(len(keys)):
                snapshot[keys[i]] = smoothed_values[i]

            reporter.update(num_batch, snapshot)

            # # Stop if the main value is exploding.
            # if num_batch > 1 and smoothed_values[0] > 4 * best_value:
            #     reporter.finish()
            #     logger.info("Stop lr finder early by default rule.")
            #     return log_lrs[split[0]:split[1]], value_matrix.T[:, split[0]:split[1]]

            # Record the best main value. The main value at index 0 is usually the training loss.
            if num_batch == 1 or smoothed_values[0] < best_value:
                best_value = smoothed_values[0]

            # Store the values.
            if num_batch == 1:
                value_matrix = smoothed_values
            else:
                value_matrix = np.vstack([value_matrix, smoothed_values])

            log_lrs.append(math.log10(lr))

            if num_batch >= num_iters:
                reporter.finish()
                return log_lrs[split[0]:split[1]], value_matrix.T[:, split[0]:split[1]]

            # Update the lr for the next step.
            lr *= mult
            optimizer.param_groups[0]['lr'] = lr

    if not utils.is_main_training():
        return None, None

    reporter.finish()
    return log_lrs[split[0]:split[1]], value_matrix.T[:, split[0]:split[1]]
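The sweep above grows the learning rate by a constant factor every batch and smooths the tracked values with a bias-corrected exponential moving average. A small standalone illustration of just those two formulas (the losses below are arbitrary example numbers, not from any real run):

import math

init_lr, final_lr, num_iters, beta = 1e-6, 10., 5, 0.98
mult = (final_lr / init_lr) ** (1 / (num_iters - 1))  # Constant per-batch multiplier.

lr, avg = init_lr, 0.
example_losses = [2.30, 2.25, 2.10, 1.95, 3.40]
for step, loss in enumerate(example_losses, start=1):
    avg = beta * avg + (1 - beta) * loss
    smoothed = avg / (1 - beta ** step)  # Bias correction, as in the wrapper.
    print("step={} log10(lr)={:.2f} smoothed_loss={:.3f}".format(step, math.log10(lr), smoothed))
    lr *= mult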
def __init__(self, trainset, valid=None, use_fast_loader=False, max_prefetch=10,
             batch_size=512, shuffle=True, num_workers=0, pin_memory=False, drop_last=True):
    num_samples = len(trainset)
    num_gpu = 1
    multi_gpu = False

    if utils.use_horovod():
        # Multi-GPU training.
        import horovod.torch as hvd

        # Partition the dataset among workers using DistributedSampler.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=shuffle)
        multi_gpu = True
        num_gpu = hvd.size()
    elif utils.use_ddp():
        # The num_replicas/world_size and rank will be set automatically with DDP.
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=shuffle)
        multi_gpu = True
        num_gpu = dist.get_world_size()
    else:
        train_sampler = None

    if multi_gpu:
        # If a DistributedSampler is used, the shuffle option of DataLoader should be set to False.
        shuffle = False
        if not utils.is_main_training():
            valid = None

    if use_fast_loader:
        self.train_loader = DataLoaderFast(max_prefetch, trainset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, pin_memory=pin_memory,
                                           drop_last=drop_last, sampler=train_sampler)
    else:
        self.train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, pin_memory=pin_memory,
                                       drop_last=drop_last, sampler=train_sampler)

    self.num_batch_train = len(self.train_loader)

    if self.num_batch_train <= 0:
        raise ValueError(
            "Expected num_batch of trainset > 0. Here is your egs info: num_gpu={}, num_samples/gpu={}, "
            "batch-size={}, drop_last={}.\nNote: If batch-size > num_samples/gpu and drop_last is true, "
            "then it will get 0 batches.".format(num_gpu, len(trainset) / num_gpu, batch_size, drop_last))

    if valid is not None:
        valid_batch_size = min(batch_size, len(valid))  # To save GPU memory.

        if len(valid) <= 0:
            raise ValueError("Expected num_samples of valid > 0.")

        # Do not use DataLoaderFast for valid, for it keeps increasing memory when compute_valid_accuracy is True.
        # The real reason has not been found yet.
        self.valid_loader = DataLoader(valid, batch_size=valid_batch_size, shuffle=False,
                                       num_workers=num_workers, pin_memory=pin_memory, drop_last=False)
        self.num_batch_valid = len(self.valid_loader)
    else:
        self.valid_loader = None
        self.num_batch_valid = 0
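A minimal sketch of the sampler/shuffle interaction handled above, using plain PyTorch only (the toy dataset and the explicit num_replicas/rank are illustrative; in a real DDP run the process group supplies them). When a DistributedSampler is passed, DataLoader(shuffle=...) must stay False because shuffling is delegated to the sampler.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=16, shuffle=False, sampler=sampler, drop_last=True)

for (xb,) in loader:
    print(xb.shape)  # Batches drawn from this rank's shard only.
    break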
def run(self):
    """Main function to start a training process.
    """
    try:
        self.init_training()

        if utils.is_main_training():
            self.reporter = Reporter(self)

        start_epoch = self.params["start_epoch"]
        epochs = self.params["epochs"]
        data = self.elements["data"]
        model = self.elements["model"]
        model_forward = self.elements["model_forward"]  # See init_training().
        lr_scheduler = self.elements["lr_scheduler"]
        base_optimizer = self.elements["optimizer"]
        best_valid_acc = 0.0

        # For lookahead.
        if getattr(base_optimizer, "optimizer", None) is not None:
            base_optimizer = base_optimizer.optimizer
        last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

        if utils.is_main_training():
            logger.info("Training will run for {0} epochs.".format(epochs))

        for this_epoch in range(start_epoch, epochs):
            if isinstance(data.train_loader.sampler, torch.utils.data.distributed.DistributedSampler):
                data.train_loader.sampler.set_epoch(this_epoch)

            for this_iter, batch in enumerate(data.train_loader, 0):
                self.training_point = (this_epoch, this_iter, data.num_batch_train)  # It is important for the reporter.

                if model.use_step:
                    model.step(*self.training_point)

                loss, acc = self.train_one_batch(batch)
                model.backward_step(*self.training_point)

                # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler,
                # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                # and some simple schedulers whose step() parameter is 'epoch' only are supported.
                lr_scheduler_params = {"training_point": self.training_point}

                valid_computed = False
                if lr_scheduler.name == "reduceP" and lr_scheduler.is_reduce_point(self.training_point):
                    assert data.valid_loader is not None
                    valid_loss, valid_acc = self.compute_validation(data.valid_loader)
                    lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc)
                    valid_computed = True

                if utils.is_main_training():
                    if valid_computed or (data.valid_loader and self.reporter.is_report(self.training_point)):
                        if not valid_computed:
                            valid_loss, valid_acc = self.compute_validation(data.valid_loader)

                        # real_snapshot is set for tensorboard to avoid a workspace problem.
                        real_snapshot = {"train_loss": loss, "valid_loss": valid_loss,
                                         "train_acc": acc * 100, "valid_acc": valid_acc * 100}
                        snapshot = {"train_loss": "{0:.6f}".format(loss),
                                    "valid_loss": "{0:.6f}".format(valid_loss),
                                    "train_acc": "{0:.2f}".format(acc * 100),
                                    "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                    "real": real_snapshot}

                        # For ReduceLROnPlateau.
                        lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc)

                        if lr_scheduler.name == "warmR":
                            if this_epoch >= epochs - 1 and valid_acc >= best_valid_acc:
                                best_valid_acc = valid_acc
                                self.save_model(from_epoch=False)
                    else:
                        real_snapshot = {"train_loss": loss, "train_acc": acc * 100}
                        snapshot = {"train_loss": "{0:.6f}".format(loss), "valid_loss": "",
                                    "train_acc": "{0:.2f}".format(acc * 100), "valid_acc": "",
                                    "real": real_snapshot}

                if lr_scheduler is not None:
                    # It is not convenient to wrap lr_scheduler (work in progress).
                    if isinstance(lr_scheduler, LRSchedulerWrapper):
                        lr_scheduler.step(**lr_scheduler_params)
                        if utils.is_main_training():
                            current_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
                            if lr_scheduler.name == "reduceP":
                                if current_lr < last_lr:
                                    last_lr = current_lr
                                    self.save_model(from_epoch=False)
                                elif current_lr <= lr_scheduler.min_lr and lr_scheduler.is_reduce_point(self.training_point):
                                    self.save_model(from_epoch=False)
                            elif lr_scheduler.name == "cyclic":
                                cyclic_size = lr_scheduler.lr_scheduler.total_size
                                current_iter = self.training_point[0] * self.training_point[2] + self.training_point[1] + 1
                                if current_iter % cyclic_size == 0 and current_iter != 1:
                                    self.save_model(from_epoch=False)
                    else:
                        # For some pytorch lr_schedulers, but it is not available for all.
                        lr_scheduler.step(this_epoch)

                if utils.is_main_training():
                    self.reporter.update(snapshot)

            if utils.is_main_training():
                # With long runs, only keep the epoch checkpoints of the last 10 epochs.
                if epochs >= 20:
                    if this_epoch >= epochs - 10:
                        print(current_lr)
                        self.save_model()
                else:
                    print(current_lr)
                    self.save_model()

        if utils.is_main_training():
            self.reporter.finish()
    except BaseException as e:
        if utils.use_ddp():
            utils.cleanup_ddp()
        if not isinstance(e, KeyboardInterrupt):
            traceback.print_exc()
        sys.exit(1)
def run(self):
    """Main function to start a training process.
    """
    try:
        self.init_training()

        if utils.is_main_training():
            self.reporter = Reporter(self)

        start_epoch = self.params["start_epoch"]
        epochs = self.params["epochs"]
        data = self.elements["data"]
        model = self.elements["model"]
        model_forward = self.elements["model_forward"]  # See init_training().
        lr_scheduler = self.elements["lr_scheduler"]

        if utils.is_main_training():
            logger.info("Training will run for {0} epochs.".format(epochs))

        for this_epoch in range(start_epoch, epochs):
            for this_iter, batch in enumerate(data.train_loader, 0):
                self.training_point = (this_epoch, this_iter, data.num_batch_train)  # It is important for the reporter.

                if model.use_step:
                    model.step(*self.training_point)

                if lr_scheduler is not None:
                    # It is not convenient to wrap lr_scheduler (work in progress).
                    if isinstance(lr_scheduler, LRSchedulerWrapper):
                        lr_scheduler.step(self.training_point)
                    else:
                        # For some pytorch lr_schedulers, but it is not available for all.
                        lr_scheduler.step(this_epoch)

                loss, acc = self.train_one_batch(batch)

                # For multi-GPU training.
                if utils.is_main_training():
                    if data.valid_loader and self.reporter.is_report(self.training_point):
                        valid_loss, valid_acc = self.compute_validation(data.valid_loader)
                        snapshot = {"train_loss": "{0:.6f}".format(loss),
                                    "valid_loss": "{0:.6f}".format(valid_loss),
                                    "train_acc": "{0:.2f}".format(acc * 100),
                                    "valid_acc": "{0:.2f}".format(valid_acc * 100)}
                    else:
                        snapshot = {"train_loss": "{0:.6f}".format(loss), "valid_loss": "",
                                    "train_acc": "{0:.2f}".format(acc * 100), "valid_acc": ""}

                if utils.is_main_training():
                    self.reporter.update(snapshot)

            if utils.is_main_training():
                self.save_model()

        if utils.is_main_training():
            self.reporter.finish()
    except BaseException as e:
        if utils.use_ddp():
            utils.cleanup_ddp()
        if not isinstance(e, KeyboardInterrupt):
            traceback.print_exc()
        sys.exit(1)