def _f_set_validation_flag(pt_model, flag_value):
    """Set the user-defined validation flag(s) on the model, if present.

    Args:
      pt_model: the model object to update
      flag_value: bool, True during validation, False during training
    """
    # 'validation' is the deprecated attribute name: still updated, but a
    # warning is printed every time it is touched
    if hasattr(pt_model, 'validation'):
        pt_model.validation = flag_value
        mes = "Warning: model.validation is deprecated, "
        mes += "please use model.flag_validation"
        nii_display.f_print(mes, 'warning')
    if hasattr(pt_model, 'flag_validation'):
        pt_model.flag_validation = flag_value
    return


def f_train_wrapper(args, pt_model, loss_wrapper, device,
                    optimizer_wrapper,
                    train_dataset_wrapper,
                    val_dataset_wrapper=None,
                    checkpoint=None):
    """
    f_train_wrapper(args, pt_model, loss_wrapper, device,
                    optimizer_wrapper
                    train_dataset_wrapper, val_dataset_wrapper = None,
                    checkpoint = None):
      A wrapper to run the training process

    Args:
       args:         argument information given by argparse. This function
                     reads multi_gpu_data_parallel,
                     ignore_optimizer_statistics_in_trained_model,
                     ignore_training_history_in_trained_model,
                     eval_mode_for_validation, not_save_each_epoch, verbose
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_wrapper:
           a wrapper over optimizer (defined in op_manager.py)
           optimizer_wrapper.optimizer is torch.optimizer
           optimizer_wrapper.lr_scheduler is the learning rate scheduler

       train_dataset_wrapper:
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.DataSetLoader

       val_dataset_wrapper:
           a wrapper over validation data set (data_io/default_data_io.py)
           it can None.

       checkpoint:
           None, or a plain dict check point that stores every thing to
           resume training, or any other object that is handed to
           load_state_dict as the model status only

    Returns:
       None. The best model and (optionally) one check point per epoch
       are written with torch.save.
    """

    nii_display.f_print_w_date("Start model training")

    ##############
    ## Preparation
    ##############

    # get the optimizer
    optimizer_wrapper.print_info()
    optimizer = optimizer_wrapper.optimizer
    lr_scheduler = optimizer_wrapper.lr_scheduler
    epoch_num = optimizer_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_wrapper.get_no_best_epoch_num()

    # get data loader for training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get data loader for val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None

    # training log information
    train_log = ''

    # prepare for DataParallism if available
    # pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        flag_multi_device = True
        nii_display.f_print("\nUse %d GPUs\n" % (torch.cuda.device_count()))
        # no way to call normtarget_f after pt_model is in DataParallel
        normtarget_f = pt_model.normalize_target
        pt_model = nn.DataParallel(pt_model)
    else:
        # torch.cuda.get_device_name() raises for a non-CUDA device, so
        # only query the GPU name when the model really runs on a GPU
        if torch.device(device).type == 'cuda':
            nii_display.f_print("\nUse single GPU: %s\n" %
                                (torch.cuda.get_device_name(device)))
        else:
            nii_display.f_print("\nUse CPU\n")
        flag_multi_device = False
        normtarget_f = None
    pt_model.to(device, dtype=nii_dconf.d_dtype)

    # print the network
    nii_nn_tools.f_model_show(pt_model)
    nii_nn_tools.f_loss_show(loss_wrapper)

    ###############################
    ## Resume training if necessary
    ###############################
    # resume training or initialize the model if necessary
    cp_names = nii_nn_manage_conf.CheckPointKey()
    if checkpoint is not None:
        # Exact type test on purpose.
        # NOTE(review): a bare state_dict is presumably an OrderedDict
        # (a dict subclass), so isinstance() would misclassify it as a
        # full check point -- confirm before changing this test.
        if type(checkpoint) is dict:
            # checkpoint

            # load model parameter and optimizer state
            if cp_names.state_dict in checkpoint:
                # wrap the state_dic in f_state_dict_wrapper
                # in case the model is saved when DataParallel is on
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(
                        checkpoint[cp_names.state_dict],
                        flag_multi_device))

            # load optimizer state
            if cp_names.optimizer in checkpoint and \
               not args.ignore_optimizer_statistics_in_trained_model:
                optimizer.load_state_dict(checkpoint[cp_names.optimizer])

            # optionally, load training history
            if not args.ignore_training_history_in_trained_model:
                if cp_names.trnlog in checkpoint:
                    monitor_trn.load_state_dic(
                        checkpoint[cp_names.trnlog])
                if cp_names.vallog in checkpoint and monitor_val:
                    monitor_val.load_state_dic(
                        checkpoint[cp_names.vallog])
                if cp_names.info in checkpoint:
                    train_log = checkpoint[cp_names.info]
                if cp_names.lr_scheduler in checkpoint and \
                   checkpoint[cp_names.lr_scheduler] and \
                   lr_scheduler.f_valid():
                    lr_scheduler.f_load_state_dict(
                        checkpoint[cp_names.lr_scheduler])

                nii_display.f_print("Load check point, resume training")
            else:
                nii_display.f_print("Load pretrained model and optimizer")
        else:
            # only model status
            pt_model.load_state_dict(
                nii_nn_tools.f_state_dict_wrapper(
                    checkpoint, flag_multi_device))
            nii_display.f_print("Load pretrained model")

    ######################
    ### User defined setup
    ######################
    if hasattr(pt_model, "other_setups"):
        nii_display.f_print("Conduct User-defined setup")
        pt_model.other_setups()

    # This should be merged with other_setups
    if hasattr(pt_model, "g_pretrained_model_path") and \
       hasattr(pt_model, "g_pretrained_model_prefix"):
        nii_display.f_print("Load pret-rained models as part of this mode")
        nii_nn_tools.f_load_pretrained_model_partially(
            pt_model, pt_model.g_pretrained_model_path,
            pt_model.g_pretrained_model_prefix)

    ######################
    ### Start training
    ######################
    # other variables
    flag_early_stopped = False
    # the monitor may have been restored from the check point, so the
    # epoch range is read back from it rather than from the optimizer
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')

    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training one epoch
        pt_model.train()
        # set validation flag if necessary
        _f_set_validation_flag(pt_model, False)

        f_run_one_epoch(args, pt_model, loss_wrapper, device,
                        monitor_trn, train_data_loader,
                        epoch_idx, optimizer, normtarget_f)
        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)

        # if necessary, do validataion
        if val_dataset_wrapper is not None:
            # set eval() if necessary
            if args.eval_mode_for_validation:
                pt_model.eval()

            # set validation flag if necessary
            _f_set_validation_flag(pt_model, True)

            # no optimizer is passed for the validation pass
            with torch.no_grad():
                f_run_one_epoch(args, pt_model, loss_wrapper,
                                device,
                                monitor_val, val_data_loader,
                                epoch_idx, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)

            # update lr rate scheduler if necessary
            if lr_scheduler.f_valid():
                lr_scheduler.f_step(loss_val)

            flag_new_best = monitor_val.is_new_best()
        else:
            # without validation data every epoch counts as the new best
            time_val, loss_val = 0, 0
            flag_new_best = True

        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val,
            flag_new_best, optimizer_wrapper.get_lr_info())

        # save the best model
        if flag_new_best:
            tmp_best_name = nii_nn_tools.f_save_trained_name(args)
            torch.save(pt_model.state_dict(), tmp_best_name)

        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            tmp_model_name = nii_nn_tools.f_save_epoch_name(args, epoch_idx)

            if monitor_val is not None:
                tmp_val_log = monitor_val.get_state_dic()
            else:
                tmp_val_log = None

            if lr_scheduler.f_valid():
                lr_scheduler_state = lr_scheduler.f_state_dict()
            else:
                lr_scheduler_state = None

            # save
            tmp_dic = {
                cp_names.state_dict: pt_model.state_dict(),
                cp_names.info: train_log,
                cp_names.optimizer: optimizer.state_dict(),
                cp_names.trnlog: monitor_trn.get_state_dic(),
                cp_names.vallog: tmp_val_log,
                cp_names.lr_scheduler: lr_scheduler_state
            }
            torch.save(tmp_dic, tmp_model_name)
            if args.verbose == 1:
                nii_display.f_eprint(str(datetime.datetime.now()))
                nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                     flush=True)

        # Early stopping
        #  note: if LR scheduler is used, early stopping will be
        #  disabled
        if lr_scheduler.f_allow_early_stopping() and \
           monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break

    # loop done
    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    nii_display.f_print("{}".format(nii_nn_tools.f_save_trained_name(args)))
    return
def main():
    """ main(): the default wrapper for training and inference process
    Please prepare config.py and model.py

    The configuration module (args.module_config) and the model module
    (args.module_model) are imported by name. When args.inference is
    false, the training (and optional validation) data sets, the model,
    the loss and the optimizer are built and handed to
    nii_nn_wrapper.f_train_wrapper; otherwise the test set and model are
    built and handed to nii_nn_wrapper.f_inference_wrapper.
    """
    # arguments initialization
    args = nii_arg_parse.f_args_parsed()

    #
    nii_warn.f_print_w_date("Start program", level='h')
    nii_warn.f_print("Load module: %s" % (args.module_config))
    nii_warn.f_print("Load module: %s" % (args.module_model))
    prj_conf = importlib.import_module(args.module_config)
    prj_model = importlib.import_module(args.module_model)

    # initialization
    nii_startup.set_random_seed(args.seed, args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # NOTE: torch.load() unpickles the file content; only load check
    # points that come from a trusted source.
    # map_location=device lets a check point that was saved on a GPU be
    # loaded when running on CPU (--no-cuda or no GPU available).

    # prepare data io
    if not args.inference:
        params = {'batch_size': args.batch_size,
                  'shuffle': args.shuffle,
                  'num_workers': args.num_workers}

        # Load file list and create data loader
        trn_lst = nii_list_tool.read_list_from_text(prj_conf.trn_list)
        trn_set = nii_dset.NIIDataSetLoader(
            prj_conf.trn_set_name,
            trn_lst,
            prj_conf.input_dirs,
            prj_conf.input_exts,
            prj_conf.input_dims,
            prj_conf.input_reso,
            prj_conf.input_norm,
            prj_conf.output_dirs,
            prj_conf.output_exts,
            prj_conf.output_dims,
            prj_conf.output_reso,
            prj_conf.output_norm,
            './',
            params=params,
            truncate_seq=prj_conf.truncate_seq,
            min_seq_len=prj_conf.minimum_len,
            save_mean_std=True,
            wav_samp_rate=prj_conf.wav_samp_rate)

        if prj_conf.val_list is not None:
            val_lst = nii_list_tool.read_list_from_text(prj_conf.val_list)
            # same features as the training set; save_mean_std is off
            # for the validation set
            val_set = nii_dset.NIIDataSetLoader(
                prj_conf.val_set_name,
                val_lst,
                prj_conf.input_dirs,
                prj_conf.input_exts,
                prj_conf.input_dims,
                prj_conf.input_reso,
                prj_conf.input_norm,
                prj_conf.output_dirs,
                prj_conf.output_exts,
                prj_conf.output_dims,
                prj_conf.output_reso,
                prj_conf.output_norm,
                './',
                params=params,
                truncate_seq=prj_conf.truncate_seq,
                min_seq_len=prj_conf.minimum_len,
                save_mean_std=False,
                wav_samp_rate=prj_conf.wav_samp_rate)
        else:
            val_set = None

        # initialize the model and loss function
        model = prj_model.Model(trn_set.get_in_dim(),
                                trn_set.get_out_dim(),
                                args, trn_set.get_data_mean_std())
        loss_wrapper = prj_model.Loss(args)

        # initialize the optimizer
        optimizer_wrapper = nii_op_wrapper.OptimizerWrapper(model, args)

        # if necessary, resume training
        if args.trained_model == "":
            checkpoint = None
        else:
            checkpoint = torch.load(args.trained_model,
                                    map_location=device)

        # start training
        nii_nn_wrapper.f_train_wrapper(args, model,
                                       loss_wrapper, device,
                                       optimizer_wrapper,
                                       trn_set, val_set, checkpoint)
        # done for traing

    else:

        # for inference

        # default, no truncating, no shuffling
        params = {'batch_size': args.batch_size,
                  'shuffle': False,
                  'num_workers': args.num_workers}

        # test_list is either already a list of names or the path of a
        # text file that holds the list
        if isinstance(prj_conf.test_list, list):
            t_lst = prj_conf.test_list
        else:
            t_lst = nii_list_tool.read_list_from_text(prj_conf.test_list)
        test_set = nii_dset.NIIDataSetLoader(
            prj_conf.test_set_name,
            t_lst,
            prj_conf.test_input_dirs,
            prj_conf.input_exts,
            prj_conf.input_dims,
            prj_conf.input_reso,
            prj_conf.input_norm,
            prj_conf.test_output_dirs,
            prj_conf.output_exts,
            prj_conf.output_dims,
            prj_conf.output_reso,
            prj_conf.output_norm,
            './',
            params=params,
            truncate_seq=None,
            min_seq_len=None,
            save_mean_std=False,
            wav_samp_rate=prj_conf.wav_samp_rate)

        # initialize model
        model = prj_model.Model(test_set.get_in_dim(),
                                test_set.get_out_dim(),
                                args)
        if args.trained_model == "":
            print("No model is loaded by ---trained-model for inference")
            print("By default, load %s%s" % (args.save_trained_name,
                                             args.save_model_ext))
            checkpoint = torch.load(
                "%s%s" % (args.save_trained_name, args.save_model_ext),
                map_location=device)
        else:
            checkpoint = torch.load(args.trained_model,
                                    map_location=device)

        # do inference and output data
        nii_nn_wrapper.f_inference_wrapper(args, model, device,
                                           test_set, checkpoint)
    # done
    return
def f_train_wrapper_GAN(
        args, pt_model_G, pt_model_D, loss_wrapper, device,
        optimizer_G_wrapper, optimizer_D_wrapper,
        train_dataset_wrapper,
        val_dataset_wrapper=None,
        checkpoint_G=None, checkpoint_D=None):
    """
    f_train_wrapper_GAN(
       args, pt_model_G, pt_model_D, loss_wrapper, device,
       optimizer_G_wrapper, optimizer_D_wrapper,
       train_dataset_wrapper, val_dataset_wrapper = None,
       checkpoint_G = None, checkpoint_D = None):

      A wrapper to run the training process

    Args:
       args:         argument information given by argparse. This function
                     reads multi_gpu_data_parallel,
                     ignore_training_history_in_trained_model,
                     eval_mode_for_validation, not_save_each_epoch, verbose
       pt_model_G:   generator, pytorch model (torch.nn.Module)
       pt_model_D:   discriminator, pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss functions; it is only passed on
                     to the per-epoch routine. If it has a true attribute
                     flag_wgan, f_run_one_epoch_WGAN is used, otherwise
                     f_run_one_epoch_GAN
       device:       torch.device("cuda") or torch.device("cpu")

       optimizer_G_wrapper:
           a optimizer wrapper for generator (defined in op_manager.py)
           the epoch number and the early-stopping patience are taken
           from this wrapper
       optimizer_D_wrapper:
           a optimizer wrapper for discriminator (defined in op_manager.py)

       train_dataset_wrapper:
           a wrapper over training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.DataSetLoader

       val_dataset_wrapper:
           a wrapper over validation data set (data_io/default_data_io.py)
           it can None.

       checkpoint_G:
           None, a plain dict check point that stores every thing to
           resume training, or the model status only (generator)
       checkpoint_D:
           same as checkpoint_G, for the discriminator

    Returns:
       None. Models are written with torch.save.
    """

    nii_display.f_print_w_date("Start model training")

    # get the optimizer
    optimizer_G_wrapper.print_info()
    optimizer_D_wrapper.print_info()
    optimizer_G = optimizer_G_wrapper.optimizer
    optimizer_D = optimizer_D_wrapper.optimizer
    epoch_num = optimizer_G_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_G_wrapper.get_no_best_epoch_num()

    # get data loader for training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get data loader for val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None

    # training log information
    train_log = ''
    model_tags = ["_G", "_D"]

    # DataParallel is not supported for GAN, so these two are constant.
    # They are bound before the branch so that they exist on every path.
    flag_multi_device = False
    normtarget_f = None

    # prepare for DataParallism if available
    # pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        # NOTE(review): f_die presumably terminates the program -- confirm
        nii_display.f_die("data_parallel not implemented for GAN")
    else:
        # torch.cuda.get_device_name() raises for a non-CUDA device, so
        # only query the GPU name when the models really run on a GPU
        if torch.device(device).type == 'cuda':
            nii_display.f_print("Use single GPU: %s" %
                                (torch.cuda.get_device_name(device)))
        else:
            nii_display.f_print("Use CPU")
    pt_model_G.to(device, dtype=nii_dconf.d_dtype)
    pt_model_D.to(device, dtype=nii_dconf.d_dtype)

    # print the network
    nii_display.f_print("Setup generator")
    f_model_show(pt_model_G)
    nii_display.f_print("Setup discriminator")
    f_model_show(pt_model_D)

    # resume training or initialize the model if necessary
    cp_names = CheckPointKey()
    if checkpoint_G is not None or checkpoint_D is not None:
        for checkpoint, optimizer, pt_model, model_name in \
            zip([checkpoint_G, checkpoint_D], [optimizer_G, optimizer_D],
                [pt_model_G, pt_model_D], ["Generator", "Discriminator"]):
            nii_display.f_print("For %s" % (model_name))
            # Exact type test on purpose.
            # NOTE(review): a bare state_dict is presumably an
            # OrderedDict (a dict subclass), so isinstance() would
            # misclassify it -- confirm before changing this test.
            if type(checkpoint) is dict:
                # checkpoint
                # load model parameter and optimizer state
                if cp_names.state_dict in checkpoint:
                    # wrap the state_dic in f_state_dict_wrapper
                    # in case the model is saved when DataParallel is on
                    pt_model.load_state_dict(
                        nii_nn_tools.f_state_dict_wrapper(
                            checkpoint[cp_names.state_dict],
                            flag_multi_device))
                # load optimizer state
                if cp_names.optimizer in checkpoint:
                    optimizer.load_state_dict(checkpoint[cp_names.optimizer])
                # optionally, load training history
                # (the monitors and the log are shared, so the values
                #  from the last check point that has them are kept)
                if not args.ignore_training_history_in_trained_model:
                    if cp_names.trnlog in checkpoint:
                        monitor_trn.load_state_dic(
                            checkpoint[cp_names.trnlog])
                    if cp_names.vallog in checkpoint and monitor_val:
                        monitor_val.load_state_dic(
                            checkpoint[cp_names.vallog])
                    if cp_names.info in checkpoint:
                        train_log = checkpoint[cp_names.info]
                    nii_display.f_print("Load check point, resume training")
                else:
                    nii_display.f_print("Load pretrained model and optimizer")
            elif checkpoint is not None:
                # only model status
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(
                        checkpoint, flag_multi_device))
                nii_display.f_print("Load pretrained model")
            else:
                nii_display.f_print("No pretrained model")
    # done for resume training

    # other variables
    flag_early_stopped = False
    # the monitor may have been restored from a check point, so the
    # epoch range is read back from it
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # choose the per-epoch routine
    if hasattr(loss_wrapper, "flag_wgan") and loss_wrapper.flag_wgan:
        f_wrapper_gan_one_epoch = f_run_one_epoch_WGAN
    else:
        f_wrapper_gan_one_epoch = f_run_one_epoch_GAN

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')

    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training one epoch
        pt_model_D.train()
        pt_model_G.train()

        f_wrapper_gan_one_epoch(
            args, pt_model_G, pt_model_D,
            loss_wrapper, device,
            monitor_trn, train_data_loader,
            epoch_idx, optimizer_G, optimizer_D,
            normtarget_f)

        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)

        # if necessary, do validataion
        if val_dataset_wrapper is not None:
            # set eval() if necessary
            if args.eval_mode_for_validation:
                pt_model_G.eval()
                pt_model_D.eval()
            # no optimizers are passed for the validation pass
            with torch.no_grad():
                f_wrapper_gan_one_epoch(
                    args, pt_model_G, pt_model_D,
                    loss_wrapper,
                    device,
                    monitor_val, val_data_loader,
                    epoch_idx, None, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
            flag_new_best = monitor_val.is_new_best()
        else:
            # without validation data every epoch counts as the new best
            time_val, loss_val = 0, 0
            flag_new_best = True

        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val,
            flag_new_best)

        # save the best model
        if flag_new_best:
            for pt_model, model_tag in \
                zip([pt_model_G, pt_model_D], model_tags):
                tmp_best_name = f_save_trained_name_GAN(args, model_tag)
                torch.save(pt_model.state_dict(), tmp_best_name)

        # save intermediate model if necessary
        if not args.not_save_each_epoch:
            # the monitor snapshots are the same for both models
            tmp_trn_log = monitor_trn.get_state_dic()
            if monitor_val is not None:
                tmp_val_log = monitor_val.get_state_dic()
            else:
                tmp_val_log = None

            # save model discrminator and generator
            for pt_model, optimizer, model_tag in \
                zip([pt_model_G, pt_model_D], [optimizer_G, optimizer_D],
                    model_tags):

                tmp_model_name = f_save_epoch_name_GAN(
                    args, epoch_idx, model_tag)
                # save
                tmp_dic = {
                    cp_names.state_dict: pt_model.state_dict(),
                    cp_names.info: train_log,
                    cp_names.optimizer: optimizer.state_dict(),
                    cp_names.trnlog: tmp_trn_log,
                    cp_names.vallog: tmp_val_log
                }
                torch.save(tmp_dic, tmp_model_name)
                if args.verbose == 1:
                    nii_display.f_eprint(str(datetime.datetime.now()))
                    nii_display.f_eprint(
                        "Save {:s}".format(tmp_model_name),
                        flush=True)

        # early stopping
        if monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break

    # loop done

    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    for model_tag in model_tags:
        nii_display.f_print("{}".format(
            f_save_trained_name_GAN(args, model_tag)))
    return
def f_train_wrapper(args, pt_model, loss_wrapper, device,
                    optimizer_wrapper,
                    train_dataset_wrapper,
                    val_dataset_wrapper=None,
                    checkpoint=None):
    """Run the whole training process (with optional validation).

    Args:
       args:         argument information given by argparse. This function
                     reads eval_mode_for_validation, not_save_each_epoch
                     and verbose
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over loss function, passed on to
                     f_run_one_epoch
       device:       torch.device("cuda") or torch.device("cpu")
       optimizer_wrapper:
           a wrapper over optimizer; optimizer_wrapper.optimizer is the
           optimizer, and the wrapper gives the number of epochs and the
           early-stopping patience
       train_dataset_wrapper:
           a wrapper over training data set; get_loader() gives the data
           loader and get_seq_num() the number of sequences
       val_dataset_wrapper:
           same kind of wrapper for the validation set, or None
       checkpoint:
           None, a plain dict check point used to resume training, or
           any other object, which is loaded as the model status only

    Returns:
       None. The best model and (optionally) one check point per epoch
       are written with torch.save.
    """
    nii_display.f_print_w_date("Start model training")

    # optimizer, number of epochs, early-stopping patience
    optimizer_wrapper.print_info()
    optim = optimizer_wrapper.optimizer
    max_epoch = optimizer_wrapper.get_epoch_num()
    patience = optimizer_wrapper.get_no_best_epoch_num()

    # training data and its monitor
    train_dataset_wrapper.print_info()
    trn_loader = train_dataset_wrapper.get_loader()
    trn_seq_num = train_dataset_wrapper.get_seq_num()
    trn_monitor = nii_monitor.Monitor(max_epoch, trn_seq_num)

    # validation data and its monitor, only when a validation set exists
    use_val = val_dataset_wrapper is not None
    val_monitor = None
    if use_val:
        val_dataset_wrapper.print_info()
        val_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        val_monitor = nii_monitor.Monitor(max_epoch, val_seq_num)

    # accumulated text of the training log
    log_text = ''

    # move the model to the device and show it
    pt_model.to(device, dtype=nii_dconf.d_dtype)
    f_model_show(pt_model)

    # load the check point / pre-trained model if one is given
    keys = CheckPointKey()
    if checkpoint is not None:
        if type(checkpoint) is not dict:
            # anything but a plain dict: model status only
            pt_model.load_state_dict(checkpoint)
            nii_display.f_print("Load pre-trained model")
        else:
            # plain dict: restore whatever entries it carries
            if keys.state_dict in checkpoint:
                pt_model.load_state_dict(checkpoint[keys.state_dict])
            if keys.optimizer in checkpoint:
                optim.load_state_dict(checkpoint[keys.optimizer])
            if keys.trnlog in checkpoint:
                trn_monitor.load_state_dic(checkpoint[keys.trnlog])
            if keys.vallog in checkpoint and val_monitor:
                val_monitor.load_state_dic(checkpoint[keys.vallog])
            if keys.info in checkpoint:
                log_text = checkpoint[keys.info]
            nii_display.f_print("Load check point and resume training")

    # epoch range comes from the (possibly restored) training monitor
    stopped_early = False
    first_epoch = trn_monitor.get_epoch()
    max_epoch = trn_monitor.get_max_epoch()

    # log header, followed by the log restored from the check point
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(log_text, flush=True, end='')

    for epoch_idx in range(first_epoch, max_epoch):

        # --- training pass
        pt_model.train()
        f_run_one_epoch(args, pt_model, loss_wrapper, device,
                        trn_monitor, trn_loader, epoch_idx, optim)
        trn_time = trn_monitor.get_time(epoch_idx)
        trn_loss = trn_monitor.get_loss(epoch_idx)

        # --- validation pass (no gradient, no optimizer)
        if use_val:
            if args.eval_mode_for_validation:
                pt_model.eval()
            with torch.no_grad():
                f_run_one_epoch(args, pt_model, loss_wrapper, device,
                                val_monitor, val_loader, epoch_idx, None)
            val_time = val_monitor.get_time(epoch_idx)
            val_loss = val_monitor.get_loss(epoch_idx)
            is_best = val_monitor.is_new_best()
        else:
            # no validation set: every epoch is treated as the best one
            val_time, val_loss = 0, 0
            is_best = True

        # --- one more line of the training log
        log_text += nii_op_display_tk.print_train_info(
            epoch_idx, trn_time, trn_loss, val_time, val_loss, is_best)

        # --- best model so far
        if is_best:
            best_path = f_save_trained_name(args)
            torch.save(pt_model.state_dict(), best_path)

        # --- per-epoch check point
        if not args.not_save_each_epoch:
            epoch_path = f_save_epoch_name(args, epoch_idx)
            val_state = None
            if val_monitor is not None:
                val_state = val_monitor.get_state_dic()

            snapshot = {
                keys.state_dict: pt_model.state_dict(),
                keys.info: log_text,
                keys.optimizer: optim.state_dict(),
                keys.trnlog: trn_monitor.get_state_dic(),
                keys.vallog: val_state,
            }
            torch.save(snapshot, epoch_path)
            if args.verbose == 1:
                nii_display.f_eprint(str(datetime.datetime.now()))
                nii_display.f_eprint("Save {:s}".format(epoch_path),
                                     flush=True)

        # --- early stopping, only possible with a validation monitor
        if val_monitor is not None and \
           val_monitor.should_early_stop(patience):
            stopped_early = True
            break

    # training loop is over
    nii_op_display_tk.print_log_tail()
    if stopped_early:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    nii_display.f_print("{}".format(f_save_trained_name(args)))
    return
def __init__(self, dataset_name,
             file_list,
             input_dirs, input_exts, input_dims, input_reso,
             input_norm,
             output_dirs, output_exts, output_dims, output_reso,
             output_norm,
             stats_path,
             data_format='<f4',
             params=None,
             truncate_seq=None,
             min_seq_len=None,
             save_mean_std=True,
             wav_samp_rate=None):
    """Create the data set (self.m_dataset) and its loader (self.m_loader).

    All arguments except params are forwarded, in order, to NIIDataSet.

    Args:
        dataset_name: a string to name this dataset; also shown in the
                      "Loading dataset" message
        file_list: a list of file name strings (without extension)
        input_dirs: a list of dirs from each input feature is loaded
        input_exts: a list of input feature name extentions
        input_dims: a list of input feature dimensions
        input_reso: a list of input feature temporal resolution,
                    or None
        input_norm: forwarded to NIIDataSet
        output_dirs: a list of dirs from each output feature is loaded
        output_exts: a list of output feature name extentions
        output_dims: a list of output feature dimensions
        output_reso: a list of output feature temporal resolution,
                    or None
        output_norm: forwarded to NIIDataSet
        stats_path: path to the directory of statistics(mean/std)
        data_format: method to load the data
                '<f4' (default): load data as float32m little-endian
                'htk': load data as htk format
        params: keyword arguments for torch.utils.data.DataLoader;
                nii_dconf.default_loader_conf is used when it is None
        truncate_seq: None or int, truncate data sequence into smaller
                      truncks; truncate_seq > 0 specifies the trunck
                      length
        min_seq_len: forwarded to NIIDataSet
        save_mean_std: forwarded to NIIDataSet
        wav_samp_rate: forwarded to NIIDataSet
    """
    nii_warn.f_print_w_date("Loading dataset %s" % dataset_name,
                            level="h")

    # the torch data set that does the actual loading
    self.m_dataset = NIIDataSet(
        dataset_name,
        file_list,
        input_dirs, input_exts, input_dims, input_reso, input_norm,
        output_dirs, output_exts, output_dims, output_reso, output_norm,
        stats_path, data_format,
        truncate_seq, min_seq_len,
        save_mean_std,
        wav_samp_rate)

    # the torch data loader; fall back to the default loader
    # configuration when the caller gives none
    loader_conf = nii_dconf.default_loader_conf if params is None else params
    self.m_loader = torch.utils.data.DataLoader(self.m_dataset,
                                                **loader_conf)