def print_error_for_batch(self, cnt_idx, seq_idx, epoch_idx):
    """ print_error_for_batch(cnt_idx, seq_idx, epoch_idx)
    Print the time and loss stored for one sequence in one epoch

    Args:
      cnt_idx: zero-based counter of the sequence (printed as cnt_idx + 1
               over self.seq_num)
      seq_idx: index used to read self.seq_names and the second axis of
               self.loss_mat / self.time_mat
      epoch_idx: index used on the first axis of self.loss_mat /
               self.time_mat

    The message is sent through nii_display.f_eprint. An IndexError or
    KeyError raised anywhere while building or printing the message is
    reported through nii_display.f_die.
    """
    try:
        loss_value = self.loss_mat[epoch_idx, seq_idx]
        time_value = self.time_mat[epoch_idx, seq_idx]
        # assemble the message piece by piece, then join once
        pieces = [
            "{}, ".format(self.seq_names[seq_idx]),
            "{:d}/{:d}, ".format(cnt_idx+1, self.seq_num),
            "Time: {:.6f}s, Loss: {:.6f}".format(time_value, loss_value),
        ]
        nii_display.f_eprint("".join(pieces), flush=True)
    except (IndexError, KeyError):
        # both lookup failures are reported with the same message
        nii_display.f_die("Unknown sample index in Monitor")
    return
def f_train_wrapper(args, pt_model, loss_wrapper, device, \
                    optimizer_wrapper, \
                    train_dataset_wrapper, \
                    val_dataset_wrapper = None, \
                    checkpoint = None):
    """ 
    f_train_wrapper(args, pt_model, loss_wrapper, device, 
                    optimizer_wrapper,
                    train_dataset_wrapper, val_dataset_wrapper = None,
                    checkpoint = None):
      A wrapper to run the training process

    Args:
       args:         argument information given by argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target) 
       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_wrapper: 
           a wrapper over optimizer (defined in op_manager.py)
           optimizer_wrapper.optimizer is torch.optimizer
           optimizer_wrapper.lr_scheduler is the learning rate scheduler
    
       train_dataset_wrapper: 
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.DataSetLoader
       
       val_dataset_wrapper: 
           a wrapper over validation data set (data_io/default_data_io.py)
           it can be None.
       
       checkpoint: 
           either a dict (a check point that stores every thing needed to 
           resume training, keyed by the names in CheckPointKey), or an
           object passed as a model state_dict to load_state_dict. 
           It can be None.

    Returns:
       None. The best model and, unless args.not_save_each_epoch is set, 
       one check point per epoch are written with torch.save.
    """        
    
    nii_display.f_print_w_date("Start model training")

    ##############
    ## Preparation
    ##############

    # get the optimizer
    optimizer_wrapper.print_info()
    optimizer = optimizer_wrapper.optimizer
    lr_scheduler = optimizer_wrapper.lr_scheduler
    epoch_num = optimizer_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_wrapper.get_no_best_epoch_num()
    
    # get data loader for training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get data loader for val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None
        
    # training log information
    train_log = ''

    # prepare for DataParallism if available
    # pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        flag_multi_device = True 
        nii_display.f_print("\nUse %d GPUs\n" % (torch.cuda.device_count()))
        # no way to call normtarget_f after pt_model is in DataParallel
        normtarget_f = pt_model.normalize_target
        pt_model = nn.DataParallel(pt_model)
    else:
        # torch.cuda.get_device_name() raises for a non-CUDA torch.device
        # (e.g. torch.device("cpu")), so only query it for CUDA devices.
        # Anything that is not a torch.device (no .type attribute) keeps
        # the original CUDA query.
        device_type = getattr(device, "type", None)
        if device_type is None or device_type == "cuda":
            nii_display.f_print("\nUse single GPU: %s\n" % \
                                (torch.cuda.get_device_name(device)))
        else:
            nii_display.f_print("\nUse device: %s\n" % (str(device)))
        flag_multi_device = False
        normtarget_f = None
    pt_model.to(device, dtype=nii_dconf.d_dtype)

    # print the network
    nii_nn_tools.f_model_show(pt_model)
    nii_nn_tools.f_loss_show(loss_wrapper)

    ###############################
    ## Resume training if necessary
    ###############################
    # resume training or initialize the model if necessary
    cp_names = nii_nn_manage_conf.CheckPointKey()
    if checkpoint is not None:
        # The exact type test is kept on purpose.
        # NOTE(review): a bare state_dict is presumably an OrderedDict 
        # (a dict subclass); isinstance() would send it to the check point 
        # branch -- confirm before changing this test.
        if type(checkpoint) is dict:
            # checkpoint

            # load model parameter and optimizer state
            if cp_names.state_dict in checkpoint:
                # wrap the state_dic in f_state_dict_wrapper 
                # in case the model is saved when DataParallel is on
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(
                        checkpoint[cp_names.state_dict], 
                        flag_multi_device))

            # load optimizer state
            if cp_names.optimizer in checkpoint and \
               not args.ignore_optimizer_statistics_in_trained_model:
                optimizer.load_state_dict(checkpoint[cp_names.optimizer])
            
            # optionally, load training history
            if not args.ignore_training_history_in_trained_model:
                if cp_names.trnlog in checkpoint:
                    monitor_trn.load_state_dic(checkpoint[cp_names.trnlog])
                if cp_names.vallog in checkpoint and monitor_val:
                    monitor_val.load_state_dic(checkpoint[cp_names.vallog])
                if cp_names.info in checkpoint:
                    train_log = checkpoint[cp_names.info]
                # the scheduler state is restored only when the check point
                # stores a non-empty one and the scheduler is in use
                if cp_names.lr_scheduler in checkpoint and \
                   checkpoint[cp_names.lr_scheduler] and lr_scheduler.f_valid():
                    lr_scheduler.f_load_state_dict(
                        checkpoint[cp_names.lr_scheduler])
                    
                nii_display.f_print("Load check point, resume training")
            else:
                nii_display.f_print("Load pretrained model and optimizer")
        else:
            # only model status
            pt_model.load_state_dict(
                nii_nn_tools.f_state_dict_wrapper(
                    checkpoint, flag_multi_device))
            nii_display.f_print("Load pretrained model")
            
    ######################
    ### User defined setup 
    ######################
    if hasattr(pt_model, "other_setups"):
        nii_display.f_print("Conduct User-defined setup")
        pt_model.other_setups()
    
    # This should be merged with other_setups
    if hasattr(pt_model, "g_pretrained_model_path") and \
       hasattr(pt_model, "g_pretrained_model_prefix"):
        nii_display.f_print("Load pret-rained models as part of this mode")
        nii_nn_tools.f_load_pretrained_model_partially(
            pt_model, pt_model.g_pretrained_model_path, 
            pt_model.g_pretrained_model_prefix)
        
    ######################
    ### Start training
    ######################
    # other variables
    flag_early_stopped = False
    # the monitor decides where to start (it may have been restored 
    # from the check point above)
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')
        
    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training one epoch
        pt_model.train()
        # set validation flag if necessary
        if hasattr(pt_model, 'validation'):
            pt_model.validation = False
            mes = "Warning: model.validation is deprecated, "
            mes += "please use model.flag_validation"
            nii_display.f_print(mes, 'warning')
        if hasattr(pt_model, 'flag_validation'):
            pt_model.flag_validation = False

        f_run_one_epoch(args, pt_model, loss_wrapper, device, \
                        monitor_trn, train_data_loader, \
                        epoch_idx, optimizer, normtarget_f)
        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)
        
        # if necessary, do validation 
        if val_dataset_wrapper is not None:
            # set eval() if necessary 
            if args.eval_mode_for_validation:
                pt_model.eval()

            # set validation flag if necessary
            if hasattr(pt_model, 'validation'):
                pt_model.validation = True
                mes = "Warning: model.validation is deprecated, "
                mes += "please use model.flag_validation"
                nii_display.f_print(mes, 'warning')
            if hasattr(pt_model, 'flag_validation'):
                pt_model.flag_validation = True

            # no optimizer is given for the validation pass
            with torch.no_grad():
                f_run_one_epoch(args, pt_model, loss_wrapper, \
                                device, \
                                monitor_val, val_data_loader, \
                                epoch_idx, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
            
            # update lr rate scheduler if necessary
            if lr_scheduler.f_valid():
                lr_scheduler.f_step(loss_val)

            flag_new_best = monitor_val.is_new_best()
        else:
            # without validation data, every epoch counts as the new best
            time_val, loss_val = 0, 0
            flag_new_best = True
            
        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val, 
            flag_new_best, optimizer_wrapper.get_lr_info())

        # save the best model
        if flag_new_best:
            tmp_best_name = nii_nn_tools.f_save_trained_name(args)
            torch.save(pt_model.state_dict(), tmp_best_name)
            
        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            tmp_model_name = nii_nn_tools.f_save_epoch_name(args, epoch_idx)
            
            if monitor_val is not None:
                tmp_val_log = monitor_val.get_state_dic()
            else:
                tmp_val_log = None
                
            if lr_scheduler.f_valid():
                lr_scheduler_state = lr_scheduler.f_state_dict()
            else:
                lr_scheduler_state = None

            # save
            tmp_dic = {
                cp_names.state_dict : pt_model.state_dict(),
                cp_names.info : train_log,
                cp_names.optimizer : optimizer.state_dict(),
                cp_names.trnlog : monitor_trn.get_state_dic(),
                cp_names.vallog : tmp_val_log,
                cp_names.lr_scheduler : lr_scheduler_state
            }
            torch.save(tmp_dic, tmp_model_name)
            if args.verbose == 1:
                nii_display.f_eprint(str(datetime.datetime.now()))
                nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                     flush=True)
            
        # Early stopping
        #  note: if LR scheduler is used, early stopping will be
        #  disabled
        if lr_scheduler.f_allow_early_stopping() and \
           monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break
        
    # loop done        
    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    nii_display.f_print("{}".format(nii_nn_tools.f_save_trained_name(args)))
    return
def f_train_wrapper(args, pt_model, loss_wrapper, device, \
                    optimizer_wrapper, \
                    train_dataset_wrapper, \
                    val_dataset_wrapper = None, \
                    checkpoint = None):
    """ 
    f_train_wrapper(args, pt_model, loss_wrapper, device, 
                    optimizer_wrapper,
                    train_dataset_wrapper, val_dataset_wrapper = None,
                    checkpoint = None):
      Run the training loop, with optional validation after each epoch

    Args:
       args:         argument information given by argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target) 
       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_wrapper: 
           a wrapper over optimizer (defined in op_manager.py)
           optimizer_wrapper.optimizer is torch.optimizer
    
       train_dataset_wrapper: 
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.DataSetLoader
       
       val_dataset_wrapper: 
           a wrapper over validation data set (data_io/default_data_io.py)
           it can be None.
       
       checkpoint: 
           a dict keyed by the names in CheckPointKey (to resume 
           training), or an object given directly to 
           pt_model.load_state_dict. It can be None.
    """        
    nii_display.f_print_w_date("Start model training")

    # optimizer and the stopping criteria it carries
    optimizer_wrapper.print_info()
    optimizer = optimizer_wrapper.optimizer
    epoch_num = optimizer_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_wrapper.get_no_best_epoch_num()
    
    # training data and its monitor
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    monitor_trn = nii_monitor.Monitor(
        epoch_num, train_dataset_wrapper.get_seq_num())

    # validation data and its monitor (both optional)
    flag_has_val = val_dataset_wrapper is not None
    monitor_val = None
    if flag_has_val:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        monitor_val = nii_monitor.Monitor(
            epoch_num, val_dataset_wrapper.get_seq_num())
        
    # accumulated text of the training log
    train_log = ''

    # move the network to the device, then print it
    pt_model.to(device, dtype=nii_dconf.d_dtype)
    f_model_show(pt_model)

    # resume training or initialize the model if necessary
    cp_names = CheckPointKey()
    if checkpoint is not None and type(checkpoint) is dict:
        # a full check point: restore whatever entries it contains
        restore_table = [
            (cp_names.state_dict, pt_model.load_state_dict),
            (cp_names.optimizer, optimizer.load_state_dict),
            (cp_names.trnlog, monitor_trn.load_state_dic),
        ]
        for cp_key, restore_f in restore_table:
            if cp_key in checkpoint:
                restore_f(checkpoint[cp_key])
        if cp_names.vallog in checkpoint and monitor_val:
            monitor_val.load_state_dic(checkpoint[cp_names.vallog])
        if cp_names.info in checkpoint:
            train_log = checkpoint[cp_names.info]
        nii_display.f_print("Load check point and resume training")
    elif checkpoint is not None:
        # only model status
        pt_model.load_state_dict(checkpoint)
        nii_display.f_print("Load pre-trained model")
            
    # the (possibly restored) monitor defines the epoch range
    flag_early_stopped = False
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # print the log head and any log text restored from the check point
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')
        
    for epoch_idx in range(start_epoch, epoch_num):

        # one pass over the training set
        pt_model.train()
        f_run_one_epoch(args, pt_model, loss_wrapper, device,
                        monitor_trn, train_data_loader,
                        epoch_idx, optimizer)
        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)
        
        if flag_has_val:
            # one pass over the validation set, without optimizer
            if args.eval_mode_for_validation:
                pt_model.eval()
            with torch.no_grad():
                f_run_one_epoch(args, pt_model, loss_wrapper,
                                device,
                                monitor_val, val_data_loader,
                                epoch_idx, None)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
            flag_new_best = monitor_val.is_new_best()
        else:
            # no validation set: every epoch is treated as the new best
            time_val, loss_val = 0, 0
            flag_new_best = True
            
        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val,
            flag_new_best)

        # save the best model
        if flag_new_best:
            tmp_best_name = f_save_trained_name(args)
            torch.save(pt_model.state_dict(), tmp_best_name)
            
        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            tmp_model_name = f_save_epoch_name(args, epoch_idx)
            tmp_val_log = (monitor_val.get_state_dic()
                           if monitor_val is not None else None)
            tmp_dic = {
                cp_names.state_dict : pt_model.state_dict(),
                cp_names.info : train_log,
                cp_names.optimizer : optimizer.state_dict(),
                cp_names.trnlog : monitor_trn.get_state_dic(),
                cp_names.vallog : tmp_val_log
            }
            torch.save(tmp_dic, tmp_model_name)
            if args.verbose == 1:
                nii_display.f_eprint(str(datetime.datetime.now()))
                nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                     flush=True)
            
        # early stopping is only possible with a validation monitor
        if monitor_val is None:
            continue
        if monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break
        
    # loop done        
    nii_op_display_tk.print_log_tail()
    nii_display.f_print(
        "Training finished by early stopping" if flag_early_stopped
        else "Training finished")
    nii_display.f_print("Model is saved to", end='')
    nii_display.f_print("{}".format(f_save_trained_name(args)))
    return
def f_train_wrapper_GAN(
        args, pt_model_G, pt_model_D, loss_wrapper, device, \
        optimizer_G_wrapper, optimizer_D_wrapper, \
        train_dataset_wrapper, \
        val_dataset_wrapper = None, \
        checkpoint_G = None, checkpoint_D = None):
    """ 
    f_train_wrapper_GAN(
       args, pt_model_G, pt_model_D, loss_wrapper, device, 
       optimizer_G_wrapper, optimizer_D_wrapper, 
       train_dataset_wrapper, val_dataset_wrapper = None,
       checkpoint_G = None, checkpoint_D = None):

      A wrapper to run the training process

    Args:
       args:         argument information given by argparse
       pt_model_G:   generator, pytorch model (torch.nn.Module)
       pt_model_D:   discriminator, pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss functions
                     loss_wrapper.compute_D_real(discriminator_output) 
                     loss_wrapper.compute_D_fake(discriminator_output) 
                     loss_wrapper.compute_G(discriminator_output)
                     loss_wrapper.compute_G(fake, real)
                     if loss_wrapper.flag_wgan exists and is true, 
                     f_run_one_epoch_WGAN is used for each epoch, 
                     otherwise f_run_one_epoch_GAN

       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_G_wrapper: 
           a optimizer wrapper for generator (defined in op_manager.py)
           it also provides the epoch number and the early-stopping 
           patience used for training
       optimizer_D_wrapper: 
           a optimizer wrapper for discriminator (defined in op_manager.py)
       
       train_dataset_wrapper: 
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.DataSetLoader
       
       val_dataset_wrapper: 
           a wrapper over validation data set (data_io/default_data_io.py)
           it can be None.
       
       checkpoint_G: 
           a check point (dict) that stores every thing to resume 
           training of the generator, or an object loaded as its model 
           state_dict. It can be None.
       checkpoint_D:
           same as checkpoint_G, for the discriminator

    Returns:
       None. Models are written with torch.save.
    """        
    
    nii_display.f_print_w_date("Start model training")

    # get the optimizer
    optimizer_G_wrapper.print_info()
    optimizer_D_wrapper.print_info()
    optimizer_G = optimizer_G_wrapper.optimizer
    optimizer_D = optimizer_D_wrapper.optimizer
    epoch_num = optimizer_G_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_G_wrapper.get_no_best_epoch_num()
    
    # get data loader for training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get data loader for val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None
        
    # training log information
    train_log = ''
    # file name tags for generator and discriminator
    model_tags = ["_G", "_D"]

    # Only single-device training is supported here. The two variables 
    # are bound unconditionally so that they are defined on every path.
    flag_multi_device = False
    normtarget_f = None

    # prepare for DataParallism if available
    # pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        nii_display.f_die("data_parallel not implemented for GAN")
    else:
        # torch.cuda.get_device_name() raises for a non-CUDA torch.device
        # (e.g. torch.device("cpu")), so only query it for CUDA devices.
        # Anything that is not a torch.device (no .type attribute) keeps
        # the original CUDA query.
        device_type = getattr(device, "type", None)
        if device_type is None or device_type == "cuda":
            nii_display.f_print("Use single GPU: %s" % \
                                (torch.cuda.get_device_name(device)))
        else:
            nii_display.f_print("Use device: %s" % (str(device)))
    pt_model_G.to(device, dtype=nii_dconf.d_dtype)
    pt_model_D.to(device, dtype=nii_dconf.d_dtype)

    # print the network
    nii_display.f_print("Setup generator")
    f_model_show(pt_model_G)
    nii_display.f_print("Setup discriminator")
    f_model_show(pt_model_D)

    # resume training or initialize the model if necessary
    cp_names = CheckPointKey()
    if checkpoint_G is not None or checkpoint_D is not None:
        for checkpoint, optimizer, pt_model, model_name in \
            zip([checkpoint_G, checkpoint_D], [optimizer_G, optimizer_D], 
                [pt_model_G, pt_model_D], ["Generator", "Discriminator"]):
            nii_display.f_print("For %s" % (model_name))
            # The exact type test is kept on purpose.
            # NOTE(review): a bare state_dict is presumably an OrderedDict 
            # (a dict subclass); isinstance() would send it to the check 
            # point branch -- confirm before changing this test.
            if type(checkpoint) is dict:
                # checkpoint
                # load model parameter and optimizer state
                if cp_names.state_dict in checkpoint:
                    # wrap the state_dic in f_state_dict_wrapper 
                    # in case the model is saved when DataParallel is on
                    pt_model.load_state_dict(
                        nii_nn_tools.f_state_dict_wrapper(
                            checkpoint[cp_names.state_dict], 
                            flag_multi_device))
                # load optimizer state
                if cp_names.optimizer in checkpoint:
                    optimizer.load_state_dict(checkpoint[cp_names.optimizer])
                # optionally, load training history
                # (the monitors and the log are shared by both models, so 
                # the check point processed last decides their content)
                if not args.ignore_training_history_in_trained_model:
                    if cp_names.trnlog in checkpoint:
                        monitor_trn.load_state_dic(
                            checkpoint[cp_names.trnlog])
                    if cp_names.vallog in checkpoint and monitor_val:
                        monitor_val.load_state_dic(
                            checkpoint[cp_names.vallog])
                    if cp_names.info in checkpoint:
                        train_log = checkpoint[cp_names.info]
                    nii_display.f_print("Load check point, resume training")
                else:
                    nii_display.f_print("Load pretrained model and optimizer")
            elif checkpoint is not None:
                # only model status
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(
                        checkpoint, flag_multi_device))
                nii_display.f_print("Load pretrained model")
            else:
                nii_display.f_print("No pretrained model")
    # done for resume training

    # other variables
    flag_early_stopped = False
    # the monitor decides where to start (it may have been restored 
    # from a check point above)
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # select the function that runs one epoch
    if hasattr(loss_wrapper, "flag_wgan") and loss_wrapper.flag_wgan:
        f_wrapper_gan_one_epoch = f_run_one_epoch_WGAN
    else:
        f_wrapper_gan_one_epoch = f_run_one_epoch_GAN

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')
        
    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training one epoch
        pt_model_D.train()
        pt_model_G.train()
        
        f_wrapper_gan_one_epoch(
            args, pt_model_G, pt_model_D, 
            loss_wrapper, device, \
            monitor_trn, train_data_loader, \
            epoch_idx, optimizer_G, optimizer_D, 
            normtarget_f)

        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)
        
        # if necessary, do validation 
        if val_dataset_wrapper is not None:
            # set eval() if necessary 
            if args.eval_mode_for_validation:
                pt_model_G.eval()
                pt_model_D.eval()
            # no optimizers are given for the validation pass
            with torch.no_grad():
                f_wrapper_gan_one_epoch(
                    args, pt_model_G, pt_model_D, 
                    loss_wrapper, \
                    device, \
                    monitor_val, val_data_loader, \
                    epoch_idx, None, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
            flag_new_best = monitor_val.is_new_best()
        else:
            # without validation data, every epoch counts as the new best
            time_val, loss_val = 0, 0
            flag_new_best = True
            
        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val, 
            flag_new_best)

        # save the best model
        if flag_new_best:
            for pt_model, model_tag in \
                zip([pt_model_G, pt_model_D], model_tags):
                tmp_best_name = f_save_trained_name_GAN(args, model_tag)
                torch.save(pt_model.state_dict(), tmp_best_name)
            
        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            # save model discriminator and generator
            for pt_model, optimizer, model_tag in \
                zip([pt_model_G, pt_model_D], [optimizer_G, optimizer_D], 
                    model_tags):

                tmp_model_name = f_save_epoch_name_GAN(
                    args, epoch_idx, model_tag)
                if monitor_val is not None:
                    tmp_val_log = monitor_val.get_state_dic()
                else:
                    tmp_val_log = None
                # save
                tmp_dic = {
                    cp_names.state_dict : pt_model.state_dict(),
                    cp_names.info : train_log,
                    cp_names.optimizer : optimizer.state_dict(),
                    cp_names.trnlog : monitor_trn.get_state_dic(),
                    cp_names.vallog : tmp_val_log
                }
                torch.save(tmp_dic, tmp_model_name)
                if args.verbose == 1:
                    nii_display.f_eprint(str(datetime.datetime.now()))
                    nii_display.f_eprint(
                        "Save {:s}".format(tmp_model_name), flush=True)
                
        # early stopping
        if monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break
        
    # loop done        
    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    for model_tag in model_tags:
        nii_display.f_print("{}".format(
            f_save_trained_name_GAN(args, model_tag)))
    return