def load_state_dic(self, state_dic):
    """ resume training, load the information
    """
    try:
        if self.seq_num != state_dic['seq_num']:
            nii_display.f_print(
                "Number of samples is different from previous training",
                'error')
            nii_display.f_print(
                "Please make sure that you are using the same "
                "training/development sets as before.", 'error')
            nii_display.f_print("Or\nPlease add --")
            nii_display.f_print("ignore_training_history_in_trained_model")
            nii_display.f_die(" to avoid loading training history")

        if self.epoch_num == state_dic['epoch_num']:
            self.loss_mat = state_dic['loss_mat']
            self.time_mat = state_dic['time_mat']
        else:
            # if the number of training epochs is increased, resize the buffer
            tmp_loss_mat = state_dic['loss_mat']
            self.loss_mat = np.resize(
                self.loss_mat,
                [self.epoch_num, self.seq_num, tmp_loss_mat.shape[2]])
            self.loss_mat[0:tmp_loss_mat.shape[0]] = tmp_loss_mat
            self.time_mat[0:tmp_loss_mat.shape[0]] = state_dic['time_mat']

        self.seq_num = state_dic['seq_num']
        # since the saved cur_epoch has been finished
        self.cur_epoch = state_dic['cur_epoch'] + 1
        self.best_error = state_dic['best_error']
        self.best_epoch = state_dic['best_epoch']
        self.loss_flag = state_dic['loss_flag']
        self.seq_names = {}
    except KeyError:
        nii_display.f_die("Invalid op_process_monitor state_dic")
def text2code(text, flag_lang='EN'):
    """ Convert a text string into code indices

    input
    -----
      text: string
      flag_lang: string, 'EN': English

    output
    ------
      code_seq: numpy array of integers
    """
    code_seq = []

    # parse the curly brackets
    text_trunks = toolkit_all.parse_curly_bracket(text)

    # parse
    if flag_lang == 'EN':
        # English text
        for text_trunk in text_trunks:
            code_seq += toolkit_en.text2code(text_trunk)
    else:
        # unsupported languages
        nii_warn.f_die("Error: text2code cannot handle {:s}".format(flag_lang))

    # convert to numpy format
    code_seq = np.array(code_seq, dtype=nii_dconf.h_dtype)
    return code_seq
def __init__(self, model, args):
    """ Initialize an optimizer over model.parameters()
    """
    # check validity of model
    if not hasattr(model, "parameters"):
        nii_warn.f_print("model is not torch.nn", "error")
        nii_warn.f_die("Error in creating OptimizerWrapper")

    # set optimizer type
    self.op_flag = args.optimizer
    self.lr = args.lr

    # create optimizer
    if self.op_flag == "Adam":
        self.optimizer = torch_optim.Adam(model.parameters(), lr=self.lr)
    elif self.op_flag == "RMSprop":
        self.optimizer = torch_optim.RMSprop(model.parameters(), lr=self.lr)
    else:
        nii_warn.f_print("%s not available" % (self.op_flag), "error")
        nii_warn.f_die("Please change optimizer")

    # number of epochs
    self.epochs = args.epochs
    self.no_best_epochs = args.no_best_epochs
    return
def _get_loss_for_learning_stopping(self, epoch_idx):
    # compute the average loss values
    if epoch_idx > self.cur_epoch:
        nii_display.f_print("To find loss for future epochs", 'error')
        nii_display.f_die("Op_process_monitor: error")
    if epoch_idx < 0:
        nii_display.f_print("To find loss for NULL epoch", 'error')
        nii_display.f_die("Op_process_monitor: error")
    # sum over sequences
    loss_this = np.sum(self.loss_mat[epoch_idx, :, :], axis=0)
    # compute only part of the loss for early stopping when necessary
    loss_this = np.sum(loss_this * self.loss_flag)
    return loss_this
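# Illustrative sketch (not part of the toolkit): how a loss_flag mask can
# select which loss terms count towards early stopping. The names and
# numbers below are made up for demonstration only.
import numpy as np

def _demo_early_stop_loss():
    # per-sequence losses for one epoch: 3 sequences x 2 loss terms
    loss_mat_epoch = np.array([[0.5, 1.0], [0.4, 0.9], [0.6, 1.1]])
    # only the first loss term is used for early stopping
    loss_flag = np.array([1.0, 0.0])
    loss_per_term = np.sum(loss_mat_epoch, axis=0)   # [1.5, 3.0]
    return np.sum(loss_per_term * loss_flag)         # 1.5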
def print_error_for_batch(self, cnt_idx, seq_idx, epoch_idx):
    try:
        t_1 = self.loss_mat[epoch_idx, seq_idx]
        t_2 = self.time_mat[epoch_idx, seq_idx]

        mes = "{}, ".format(self.seq_names[seq_idx])
        mes += "{:d}/{:d}, ".format(cnt_idx + 1, self.seq_num)
        mes += "Time: {:.6f}s, Loss: {:.6f}".format(t_2, t_1)
        nii_display.f_eprint(mes, flush=True)
    except IndexError:
        nii_display.f_die("Unknown sample index in Monitor")
    except KeyError:
        nii_display.f_die("Unknown sample index in Monitor")
    return
def f_check_file_list(self):
    """ f_check_file_list():
        Check the file list after initialization.
        Make sure that each file in file_list appears in every
        input/output feature directory.
        If not, produce a file_list in which every file is available
        in every input/output directory.
    """
    if not isinstance(self.m_file_list, list):
        nii_warn.f_print("Read file list from directories")
        self.m_file_list = None

    # get an initial file list
    if self.m_file_list is None:
        self.m_file_list = nii_list_tools.listdir_with_ext(
            self.m_input_dirs[0], self.m_input_exts[0])

    # check that the files exist in all input/output directories
    for tmp_d, tmp_e in zip(self.m_input_dirs[1:], \
                            self.m_input_exts[1:]):
        tmp_list = nii_list_tools.listdir_with_ext(tmp_d, tmp_e)
        self.m_file_list = nii_list_tools.common_members(
            tmp_list, self.m_file_list)

    if len(self.m_file_list) < 1:
        nii_warn.f_print("No input features after scanning", 'error')
        nii_warn.f_print("Please check input config", 'error')
        nii_warn.f_print("Please check feature directory", 'error')

    # check output files if necessary
    if self.m_output_dirs:
        for tmp_d, tmp_e in zip(self.m_output_dirs, \
                                self.m_output_exts):
            tmp_list = nii_list_tools.listdir_with_ext(tmp_d, tmp_e)
            self.m_file_list = nii_list_tools.common_members(
                tmp_list, self.m_file_list)

        if len(self.m_file_list) < 1:
            nii_warn.f_print("No output data found", 'error')
            nii_warn.f_die("Please check output config")
    else:
        #nii_warn.f_print("Not loading output features")
        pass

    # done
    return
def symbol_num(flag_lang='EN'):
    """ Return the number of symbols defined for one language

    input
    -----
      flag_lang: string, 'EN': English

    output
    ------
      integer
    """
    if flag_lang == 'EN':
        return toolkit_en.symbol_num()
    else:
        nii_warn.f_die(
            "Error: symbol_num cannot handle {:s}".format(flag_lang))
    return 0
def __init__(self, buf_dataseq_length, batch_size):
    """ SamplerBlockShuffleByLength(buf_dataseq_length, batch_size)
    args
    ----
      buf_dataseq_length: list or np.array of int,
                          length of each data sequence in the dataset
      batch_size: int, batch size
    """
    if batch_size == 1:
        mes = "Sampler block shuffle by length requires batch-size > 1"
        nii_warn.f_die(mes)

    # hyper-parameter, just let block_size = batch_size * 4
    self.m_block_size = batch_size * 4
    # indices sorted by sequence length
    self.m_idx = np.argsort(buf_dataseq_length)
    return
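# Illustrative sketch (not part of the toolkit): the idea behind block
# shuffling by length. Indices are sorted by sequence length, shuffled
# inside blocks of block_size = batch_size * 4, and the block order may
# then be shuffled as well, so each mini-batch contains sequences of
# similar length while the overall order is still randomized. Whether the
# real sampler also shuffles the block order is an assumption here.
import numpy as np

def demo_block_shuffle_by_length(seq_lengths, batch_size, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    block_size = batch_size * 4
    idx = np.argsort(seq_lengths)
    # shuffle within each block of similar-length sequences
    blocks = [idx[s:s + block_size] for s in range(0, len(idx), block_size)]
    blocks = [rng.permutation(b) for b in blocks]
    # shuffle the order of the blocks themselves
    rng.shuffle(blocks)
    return np.concatenate(blocks)

# example: demo_block_shuffle_by_length([10, 3, 7, 5, 9, 2, 8, 4], 2)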
def __init__(self, config_path):
    """ initialization
    """
    # get configuration path
    self.m_config_path = None
    if os.path.isfile(config_path):
        self.m_config_path = config_path
    else:
        nii_display.f_die("Cannot find %s" % (config_path), 'error')

    # parse the configuration file
    self.m_config = self.f_parse()
    if self.m_config is None:
        nii_display.f_die("Failed to parse %s" % (config_path), 'error')
    # done
    return
def __getitem__(self, i):
    """ getitem from the corresponding subcorpus
    """
    # for example, data1 = [a], data2 = [b, c]
    #  self.len_buffer = [1, 2]
    #  self.len_top = [1, 3]
    #  self.len_bot = [0, 1]
    #  __getitem__(0) -> data1[0-0] = a
    #  __getitem__(1) -> data2[1-1] = b
    #  __getitem__(2) -> data2[2-1] = c
    for idx_u, idx_d, subset in \
        zip(self.len_top, self.len_bot, self.datasets):
        if i < idx_u:
            return subset.__getitem__(i - idx_d)
        else:
            # keep going to the next subset
            pass
    nii_warn.f_die("Merge dataset: fatal error in __getitem__")
    return None
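# Illustrative sketch (not part of the toolkit): indexing into concatenated
# sub-datasets with cumulative upper bounds (len_top) and offsets (len_bot),
# using plain Python lists in place of real datasets.
def demo_merged_getitem(datasets, i):
    len_top, len_bot, total = [], [], 0
    for d in datasets:
        len_bot.append(total)
        total += len(d)
        len_top.append(total)
    for idx_u, idx_d, subset in zip(len_top, len_bot, datasets):
        if i < idx_u:
            return subset[i - idx_d]
    raise IndexError("index %d out of range" % i)

# example: demo_merged_getitem([['a'], ['b', 'c']], 2) returns 'c'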
def code2text(codes, flag_lang='EN'):
    """ Convert code indices back into a text string

    input
    -----
      codes: numpy array of integers
      flag_lang: string, 'EN': English

    output
    ------
      text: string
    """
    # convert numpy array back to a list of indices
    codes_tmp = [int(x) for x in codes]

    output_text = ''
    if flag_lang == 'EN':
        output_text = toolkit_en.code2text(codes_tmp)
    else:
        nii_warn.f_die("Error: code2text cannot handle {:s}".format(flag_lang))
    return output_text
def load_state_dic(self, state_dic):
    """ resume training, load the information
    """
    try:
        if self.seq_num != state_dic['seq_num']:
            nii_display.f_print(
                "Number of samples is different from previous training",
                'error')
            nii_display.f_die(
                "Please make sure the resumed training uses the same "
                "training/development sets as before")

        self.loss_mat = state_dic['loss_mat']
        self.time_mat = state_dic['time_mat']
        self.epoch_num = state_dic['epoch_num']
        self.seq_num = state_dic['seq_num']
        # since the saved cur_epoch has been finished
        self.cur_epoch = state_dic['cur_epoch'] + 1
        self.best_error = state_dic['best_error']
        self.best_epoch = state_dic['best_epoch']
        self.seq_names = {}
    except KeyError:
        nii_display.f_die("Invalid op_process_monitor state_dic")
def f_loss_check(loss_module, model_type=None):
    """ f_loss_check(loss_module, model_type=None)
    Check whether the loss module contains all the necessary keywords

    Args:
    ----
      loss_module: a class
      model_type: a str or None

    Return:
    -------
    """
    nii_display.f_print("Loss check")

    if model_type in nii_nn_manage_conf.loss_method_keywords_bags:
        keywords_bag = nii_nn_manage_conf.loss_method_keywords_bags[model_type]
    else:
        keywords_bag = nii_nn_manage_conf.loss_method_keywords_default

    for tmpkey in keywords_bag.keys():
        flag_mandatory, mes = keywords_bag[tmpkey]

        # mandatory keywords
        if flag_mandatory:
            if not hasattr(loss_module, tmpkey):
                nii_display.f_print("Please implement %s (%s)" % (tmpkey, mes))
                nii_display.f_die("[Error]: found no %s in Loss" % (tmpkey))
            else:
                # no need to print other information here
                pass #print("[OK]: %s found" % (tmpkey))
        else:
            if not hasattr(loss_module, tmpkey):
                # no need to print other information here
                pass #print("[OK]: %s is ignored, %s" % (tmpkey, mes))
            else:
                print("[OK]: use %s, %s" % (tmpkey, mes))
    # done
    nii_display.f_print("Loss check done\n")
    return
def f_retrieve(self, keyword, section_name=None, config_type=None):
    """ f_retrieve(self, keyword, section_name=None, config_type=None)
    Retrieve the keyword from the config file

    Return:
      value: string, int, or float

    Parameters:
      keyword: 'keyword' to be retrieved
      section_name: which section the keyword belongs to.
                    None will search all the config sections and
                    return the first match
      config_type: 'int', 'float', 'bool', or None.
                   None will return the value as a string
    """
    tmp_value = None
    if section_name is None:
        # if section is not given, search all the sections
        for section_name in self.m_config.sections():
            tmp_value = self.f_retrieve(keyword, section_name, config_type)
            if tmp_value is not None:
                break
    elif section_name in self.m_config.sections() or \
         section_name == 'DEFAULT':
        # search a specific section
        tmp_sec = self.m_config[section_name]
        if config_type == 'int':
            tmp_value = tmp_sec.getint(keyword, fallback=None)
        elif config_type == 'float':
            tmp_value = tmp_sec.getfloat(keyword, fallback=None)
        elif config_type == 'bool':
            tmp_value = tmp_sec.getboolean(keyword, fallback=None)
        else:
            tmp_value = tmp_sec.get(keyword, fallback=None)
    else:
        nii_display.f_die("Unknown section %s" % (section_name))
    return tmp_value
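# Illustrative sketch (not part of the toolkit): the typed-retrieval pattern
# above with Python's standard configparser. The section and option names
# below are made up for demonstration only.
import configparser

def demo_config_retrieve():
    cfg = configparser.ConfigParser()
    cfg.read_string("[train]\nlr = 0.001\nbatch_size = 32\nshuffle = yes\n")
    sec = cfg['train']
    lr = sec.getfloat('lr', fallback=None)              # 0.001
    batch = sec.getint('batch_size', fallback=None)     # 32
    shuffle = sec.getboolean('shuffle', fallback=None)  # True
    missing = sec.get('optimizer', fallback=None)       # None
    return lr, batch, shuffle, missing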
def f_log_data_len(self, file_name, t_len, t_reso):
    """ f_log_data_len(file_name, t_len, t_reso):
    Log down the length of the data file.

    When comparing the different input/output features for the same
    file_name, only the shortest length is kept.
    """
    # length of the sequence at the fastest temporal rate
    # For example, acoustic-feature -> waveform 16kHz,
    #  if the acoustic feature has one frame per 5 ms,
    #  tmp_len = acoustic feature frame length * (5 * 16)
    #  where t_reso = 5 * 16 is the up-sampling rate of the acoustic feature
    tmp_len = t_len * t_reso

    # save the length if this file has not been read yet
    if file_name not in self.m_data_length:
        self.m_data_length[file_name] = tmp_len

    # check length
    if t_len == 1:
        # if this is an utterance-level feature, it has only 1 frame
        pass
    elif self.f_valid_len(self.m_data_length[file_name], tmp_len, \
                          nii_dconf.data_seq_min_length):
        # if the difference in length is small, keep the shorter one
        if self.m_data_length[file_name] > tmp_len:
            self.m_data_length[file_name] = tmp_len
    else:
        nii_warn.f_print("Sequence length mismatch:", 'error')
        self.f_check_specific_data(file_name)
        nii_warn.f_print("Please check the above features", 'error')
        nii_warn.f_die("Possible invalid data %s" % (file_name))

    # adjust the length so that, when reso is used,
    # the sequence length will be N * reso
    tmp = self.m_data_length[file_name]
    self.m_data_length[file_name] = self.f_adjust_len(tmp)
    return
def __init__(self, model, args):
    """ Initialize an optimizer over model.parameters()
    """
    # check validity of model
    if not hasattr(model, "parameters"):
        nii_warn.f_print("model is not torch.nn", "error")
        nii_warn.f_die("Error in creating OptimizerWrapper")

    # set optimizer type
    self.op_flag = args.optimizer
    self.lr = args.lr
    self.l2_penalty = args.l2_penalty

    # grad clip norm is directly applied in nn_manager
    self.grad_clip_norm = args.grad_clip_norm

    # create optimizer
    if self.op_flag == "Adam":
        if self.l2_penalty > 0:
            self.optimizer = torch_optim.Adam(
                model.parameters(), lr=self.lr,
                weight_decay=self.l2_penalty)
        else:
            self.optimizer = torch_optim.Adam(model.parameters(), lr=self.lr)
    else:
        nii_warn.f_print("%s not available" % (self.op_flag), "error")
        nii_warn.f_die("Please change optimizer")

    # number of epochs
    self.epochs = args.epochs
    self.no_best_epochs = args.no_best_epochs

    # lr scheduler
    self.lr_scheduler = nii_lr_scheduler.LRScheduler(self.optimizer, args)
    return
def f_check_specific_data(self, file_name):
    """ check the data length of a specific file
    """
    tmp_dirs = self.m_input_dirs.copy()
    tmp_exts = self.m_input_exts.copy()
    tmp_dims = self.m_input_dims.copy()
    tmp_reso = self.m_input_reso.copy()
    tmp_dirs.extend(self.m_output_dirs)
    tmp_exts.extend(self.m_output_exts)
    tmp_dims.extend(self.m_output_dims)
    tmp_reso.extend(self.m_output_reso)

    # loop over each input/output feature type
    for t_dir, t_ext, t_dim, t_res in \
        zip(tmp_dirs, tmp_exts, tmp_dims, tmp_reso):

        file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
        if not nii_io_tk.file_exist(file_path):
            nii_warn.f_die("%s not found" % (file_path))
        else:
            t_len = self.f_length_data(file_path) // t_dim
            print("%s, length %d, dim %d, reso: %d" % \
                  (file_path, t_len, t_dim, t_res))
    return
def f_model_check(pt_model, model_type=None):
    """ f_model_check(pt_model)
    Check whether the model contains all the necessary keywords

    Args:
    ----
      pt_model: a Pytorch model
      model_type: str or None, a flag indicating the type of network

    Return:
    -------
    """
    nii_display.f_print("Model check:")

    if model_type in nii_nn_manage_conf.nn_model_keywords_bags:
        keywords_bag = nii_nn_manage_conf.nn_model_keywords_bags[model_type]
    else:
        keywords_bag = nii_nn_manage_conf.nn_model_keywords_default

    for tmpkey in keywords_bag.keys():
        flag_mandatory, mes = keywords_bag[tmpkey]

        # mandatory keywords
        if flag_mandatory:
            if not hasattr(pt_model, tmpkey):
                nii_display.f_print("Please implement %s (%s)" % (tmpkey, mes))
                nii_display.f_die("[Error]: found no %s in Model" % (tmpkey))
            else:
                print("[OK]: %s found" % (tmpkey))
        else:
            if not hasattr(pt_model, tmpkey):
                print("[OK]: %s is ignored, %s" % (tmpkey, mes))
            else:
                print("[OK]: use %s, %s" % (tmpkey, mes))
    # done
    nii_display.f_print("Model check done\n")
    return
def f_putitem(self, output_data, save_dir, data_infor_str):
    """ write the generated data to the output directory
    """
    # Change the dimension to (length, dim)
    if output_data.ndim == 3 and output_data.shape[0] == 1:
        # When input data is (batchsize=1, length, dim)
        output_data = output_data[0]
    elif output_data.ndim == 2 and output_data.shape[0] == 1:
        # When input data is (batchsize=1, length)
        output_data = np.expand_dims(output_data[0], -1)
    else:
        nii_warn.f_print("Output data format not supported.", "error")
        nii_warn.f_print("Format is not (batch, len, dim)", "error")
        nii_warn.f_die("Please use batch_size = 1 in generation")

    # Save output
    if output_data.shape[1] != self.m_output_all_dim:
        nii_warn.f_print("Output data dim != expected dim", "error")
        nii_warn.f_print("Output: %d" % (output_data.shape[1]), "error")
        nii_warn.f_print("Expected: %d" % (self.m_output_all_dim), "error")
        nii_warn.f_die("Please check configuration")

    if not os.path.isdir(save_dir):
        try:
            os.mkdir(save_dir)
        except OSError:
            nii_warn.f_die("Cannot create {}".format(save_dir))

    # read the sentence information
    tmp_seq_info = nii_seqinfo.SeqInfo()
    tmp_seq_info.parse_from_str(data_infor_str)

    # write the data, one slice per output feature type
    file_name = tmp_seq_info.seq_tag()
    s_dim = 0
    e_dim = 0
    for t_ext, t_dim in zip(self.m_output_exts, self.m_output_dims):
        e_dim = s_dim + t_dim
        file_path = nii_str_tk.f_realpath(save_dir, file_name, t_ext)
        self.f_write_data(output_data[:, s_dim:e_dim], file_path)
        # move to the next feature slice
        s_dim = e_dim
    return
def f_inference_wrapper(args, pt_model, device, \
                        test_dataset_wrapper, checkpoint):
    """ Wrapper for inference
    """
    # prepare dataloader
    test_data_loader = test_dataset_wrapper.get_loader()
    test_seq_num = test_dataset_wrapper.get_seq_num()
    test_dataset_wrapper.print_info()

    # cuda device
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        nii_display.f_print(
            "DataParallel for inference is not implemented", 'warning')
    nii_display.f_print("\nUse single GPU: %s\n" % \
                        (torch.cuda.get_device_name(device)))

    # print the network
    pt_model.to(device, dtype=nii_dconf.d_dtype)
    nii_nn_tools.f_model_show(pt_model)

    # load trained model parameters from checkpoint
    cp_names = nii_nn_manage_conf.CheckPointKey()
    if type(checkpoint) is dict and cp_names.state_dict in checkpoint:
        pt_model.load_state_dict(checkpoint[cp_names.state_dict])
    else:
        pt_model.load_state_dict(checkpoint)

    # start generation
    nii_display.f_print("Start inference (generation):", 'highlight')
    pt_model.eval()
    with torch.no_grad():
        for _, (data_in, data_tar, data_info, idx_orig) in \
            enumerate(test_data_loader):

            # send data to device and convert data type
            data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
            if isinstance(data_tar, torch.Tensor):
                data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)

            # compute output
            start_time = time.time()

            # in case the model defines an inference function explicitly
            if hasattr(pt_model, "inference"):
                infer_func = pt_model.inference
            else:
                infer_func = pt_model.forward

            if args.model_forward_with_target:
                # if model.forward requires (input, target) as arguments
                # for example, for auto-encoder
                if args.model_forward_with_file_name:
                    data_gen = infer_func(data_in, data_tar, data_info)
                else:
                    data_gen = infer_func(data_in, data_tar)
            else:
                if args.model_forward_with_file_name:
                    data_gen = infer_func(data_in, data_info)
                else:
                    data_gen = infer_func(data_in)

            time_cost = time.time() - start_time
            # average time for each sequence when batchsize > 1
            time_cost = time_cost / len(data_info)

            if data_gen is None:
                nii_display.f_print(
                    "No output saved: %s" % (str(data_info)), 'warning')
                for idx, seq_info in enumerate(data_info):
                    _ = nii_op_display_tk.print_gen_info(seq_info, time_cost)
                continue
            else:
                try:
                    data_gen = pt_model.denormalize_output(data_gen)
                    data_gen_np = data_gen.to("cpu").numpy()
                except AttributeError:
                    mes = "Output data is not torch.tensor. Please check "
                    mes += "model.forward or model.inference"
                    nii_display.f_die(mes)

                # save output (in case batchsize > 1)
                for idx, seq_info in enumerate(data_info):
                    _ = nii_op_display_tk.print_gen_info(seq_info, time_cost)
                    test_dataset_wrapper.putitem(data_gen_np[idx:idx+1], \
                                                 args.output_dir, \
                                                 seq_info)
        # done for
    # done with

    # nii_display.f_print("Generated data to %s" % (args.output_dir))
    # finish up if necessary
    if hasattr(pt_model, "finish_up_inference"):
        pt_model.finish_up_inference()

    # done
    return
def f_run_one_epoch(args, pt_model, loss_wrapper, \
                    device, monitor, \
                    data_loader, epoch_idx, optimizer = None, \
                    target_norm_method = None):
    """
    f_run_one_epoch:
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over the loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader.
       epoch_idx:    int, index of the current epoch
       optimizer:    torch optimizer or None
                     if None, back propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                           (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        #############
        # prepare
        #############
        # idx_orig is the original idx in the dataset,
        # which can be different from data_idx when shuffle = True
        #idx_orig = idx_orig.numpy()[0]
        #data_seq_info = data_info[0]

        # send data to device
        if optimizer is not None:
            optimizer.zero_grad()

        ############
        # compute output
        ############
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model(data_in)

        #####################
        # compute loss and do back propagation
        #####################
        # Two cases
        # 1. if loss is defined as pt_model.loss, then let the users do
        #    normalization inside pt_model.loss
        # 2. if loss_wrapper is defined as a class independent from the model,
        #    there is no way to normalize the data inside the loss_wrapper
        #    because the normalization weights are saved in pt_model
        if hasattr(pt_model, 'loss'):
            # case 1, pt_model.loss is available
            if isinstance(data_tar, torch.Tensor):
                data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            else:
                data_tar = []

            loss_computed = pt_model.loss(data_gen, data_tar)
        else:
            # case 2, loss is defined independently of pt_model
            if isinstance(data_tar, torch.Tensor):
                data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
                # there is no way to normalize the data inside the loss,
                # thus, do normalization here
                if target_norm_method is None:
                    normed_target = pt_model.normalize_target(data_tar)
                else:
                    normed_target = target_norm_method(data_tar)
            else:
                normed_target = []

            # return the loss from loss_wrapper
            # loss_computed may be [[loss_1, loss_2, ...], [flag_1, flag_2, ...]],
            #  which contains multiple losses and flags indicating whether
            #  the corresponding loss should be taken into consideration
            #  for early stopping,
            # or loss_computed may simply be a tensor loss
            loss_computed = loss_wrapper.compute(data_gen, normed_target)

        loss_values = [0]
        # To handle cases where there are multiple loss functions:
        # when loss_computed is [[loss_1, loss_2, ...], [flag_1, flag_2, ...]]
        #   loss: sum of [loss_1, loss_2, ...], for backward()
        #   loss_values: [loss_1.item(), loss_2.item(), ...], for logging
        #   loss_flags: [True/False, ...], for logging,
        #     whether loss_n is used for early stopping
        # when loss_computed is a single loss
        #   loss: loss
        #   loss_values: [loss.item()]
        #   loss_flags: [True]
        loss, loss_values, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # Back-propagation using the summed loss
        if optimizer is not None:
            # backward propagation
            loss.backward()

            # apply gradient clipping
            if args.grad_clip_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    pt_model.parameters(), args.grad_clip_norm)

            # update parameters
            optimizer.step()

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_values, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

        # Save an intermediate model every n mini-batches (optional).
        # Note that if we re-start training with this intermediate model,
        # the data loader will start from the 1st sample, not the one where
        # we stopped
        if args.save_model_every_n_minibatches > 0 \
           and (data_idx + 1) % args.save_model_every_n_minibatches == 0 \
           and optimizer is not None and data_idx > 0:
            cp_names = nii_nn_manage_conf.CheckPointKey()
            tmp_model_name = nii_nn_tools.f_save_epoch_name(
                args, epoch_idx, '_{:05d}'.format(data_idx + 1))
            # save
            tmp_dic = {
                cp_names.state_dict: pt_model.state_dict(),
                cp_names.optimizer: optimizer.state_dict()
            }
            torch.save(tmp_dic, tmp_model_name)

    # loop done
    return
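# Illustrative sketch (not the toolkit's nii_nn_tools.f_process_loss): one
# way to turn either a single loss tensor or a [[loss_1, loss_2, ...],
# [flag_1, ...]] pair into (summed loss for backward, item values for
# logging, early-stopping flags), as described in the comments above.
import torch

def demo_process_loss(loss_computed):
    if isinstance(loss_computed, torch.Tensor):
        # single loss tensor
        return loss_computed, [loss_computed.item()], [True]
    # multiple losses with early-stopping flags
    losses, flags = loss_computed
    total = sum(losses)
    return total, [l.item() for l in losses], list(flags)

# example:
#   demo_process_loss(torch.tensor(0.5))
#   demo_process_loss([[torch.tensor(0.5), torch.tensor(1.0)], [True, False]])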
def __getitem__(self, idx):
    """ __getitem__(self, idx):
    Return input, output

    For test set data, output can be None
    """
    try:
        tmp_seq_info = self.m_seq_info[idx]
    except IndexError:
        nii_warn.f_die("Sample %d is not in seq_info" % (idx))

    # file_name
    file_name = tmp_seq_info.seq_tag()

    # For input data
    input_reso = self.m_input_reso[0]
    seq_len = int(tmp_seq_info.seq_length() // input_reso)
    s_idx = (tmp_seq_info.seq_start_pos() // input_reso)
    e_idx = s_idx + seq_len

    input_dim = self.m_input_all_dim
    in_data = np.zeros([seq_len, input_dim], dtype=nii_dconf.h_dtype)
    s_dim = 0
    e_dim = 0

    # loop over each feature type
    for t_dir, t_ext, t_dim, t_res in \
        zip(self.m_input_dirs, self.m_input_exts, \
            self.m_input_dims, self.m_input_reso):
        e_dim = s_dim + t_dim

        # get file path and load data
        file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
        try:
            tmp_d = self.f_load_data(file_path, t_dim)
        except IOError:
            nii_warn.f_die("Cannot find %s" % (file_path))

        # write data
        if tmp_d.shape[0] == 1:
            # input data has only one frame, duplicate it
            if tmp_d.ndim > 1:
                in_data[:, s_dim:e_dim] = tmp_d[0, :]
            elif t_dim == 1:
                in_data[:, s_dim] = tmp_d
            else:
                nii_warn.f_die("Dimension wrong %s" % (file_path))
        else:
            # normal case
            if tmp_d.ndim > 1:
                # write multi-dimensional data
                in_data[:, s_dim:e_dim] = tmp_d[s_idx:e_idx, :]
            elif t_dim == 1:
                # write one-dimensional data
                in_data[:, s_dim] = tmp_d[s_idx:e_idx]
            else:
                nii_warn.f_die("Dimension wrong %s" % (file_path))
        s_dim = e_dim

    # load output data
    if self.m_output_dirs:
        seq_len = tmp_seq_info.seq_length()
        s_idx = tmp_seq_info.seq_start_pos()
        e_idx = s_idx + seq_len

        out_dim = self.m_output_all_dim
        out_data = np.zeros([seq_len, out_dim], \
                            dtype=nii_dconf.h_dtype)
        s_dim = 0
        e_dim = 0
        for t_dir, t_ext, t_dim in zip(self.m_output_dirs, \
                                       self.m_output_exts, \
                                       self.m_output_dims):
            e_dim = s_dim + t_dim

            # get file path and load data
            file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
            try:
                tmp_d = self.f_load_data(file_path, t_dim)
            except IOError:
                nii_warn.f_die("Cannot find %s" % (file_path))

            if tmp_d.shape[0] == 1:
                if tmp_d.ndim > 1:
                    out_data[:, s_dim:e_dim] = tmp_d[0, :]
                elif t_dim == 1:
                    out_data[:, s_dim] = tmp_d
                else:
                    nii_warn.f_die("Dimension wrong %s" % (file_path))
            else:
                if tmp_d.ndim > 1:
                    out_data[:, s_dim:e_dim] = tmp_d[s_idx:e_idx, :]
                elif t_dim == 1:
                    out_data[:, s_dim] = tmp_d[s_idx:e_idx]
                else:
                    nii_warn.f_die("Dimension wrong %s" % (file_path))
            s_dim = s_dim + t_dim
    else:
        out_data = []

    return in_data, out_data, tmp_seq_info.print_to_str(), idx
def f_calculate_stats(self, flag_cal_data_len, flag_cal_mean_std):
    """ f_calculate_stats
    Log down the number of time steps for each file
    Calculate the mean/std
    """
    # check
    #if not self.m_output_dirs:
    #    nii_warn.f_print("Calculating mean/std", 'error')
    #    nii_warn.f_die("But output_dirs is not provided")

    # prepare the directories, extensions, and dimensions
    tmp_dirs = self.m_input_dirs.copy()
    tmp_exts = self.m_input_exts.copy()
    tmp_dims = self.m_input_dims.copy()
    tmp_reso = self.m_input_reso.copy()
    tmp_norm = self.m_input_norm.copy()

    tmp_dirs.extend(self.m_output_dirs)
    tmp_exts.extend(self.m_output_exts)
    tmp_dims.extend(self.m_output_dims)
    tmp_reso.extend(self.m_output_reso)
    tmp_norm.extend(self.m_output_norm)

    # starting dimension of one type of feature
    s_dim = 0
    # ending dimension of one type of feature
    e_dim = 0

    # loop over each input/output feature type
    for t_dir, t_ext, t_dim, t_reso, t_norm in \
        zip(tmp_dirs, tmp_exts, tmp_dims, tmp_reso, tmp_norm):

        s_dim = e_dim
        e_dim = s_dim + t_dim
        t_cnt = 0
        mean_i, var_i = np.zeros([t_dim]), np.zeros([t_dim])

        # loop over all the data
        for file_name in self.m_file_list:
            # get file path
            file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
            if not nii_io_tk.file_exist(file_path):
                nii_warn.f_die("%s not found" % (file_path))

            # read the length of the data
            if flag_cal_data_len:
                t_len = self.f_length_data(file_path) // t_dim
                self.f_log_data_len(file_name, t_len, t_reso)

            # accumulate the mean/std recursively
            if flag_cal_mean_std:
                t_data = self.f_load_data(file_path, t_dim)

                # if this is F0 data, only consider voiced frames
                if t_ext in nii_dconf.f0_unvoiced_dic:
                    unvoiced_value = nii_dconf.f0_unvoiced_dic[t_ext]
                    t_data = t_data[t_data > unvoiced_value]

                # mean_i, var_i, t_cnt will be updated using the online
                # accumulation method
                mean_i, var_i, t_cnt = nii_stats.f_online_mean_std(
                    t_data, mean_i, var_i, t_cnt)

        # save mean and std for one feature type
        if flag_cal_mean_std:
            # if this dimension should not be normalized, set mean=0, std=1
            if not t_norm:
                mean_i[:] = 0
                var_i[:] = 1

            if s_dim < self.m_input_all_dim:
                self.m_input_mean[s_dim:e_dim] = mean_i
                std_i = nii_stats.f_var2std(var_i)
                self.m_input_std[s_dim:e_dim] = std_i
            else:
                tmp_s = s_dim - self.m_input_all_dim
                tmp_e = e_dim - self.m_input_all_dim
                self.m_output_mean[tmp_s:tmp_e] = mean_i
                std_i = nii_stats.f_var2std(var_i)
                self.m_output_std[tmp_s:tmp_e] = std_i

    if flag_cal_data_len:
        # create seq_info
        self.f_log_seq_info()
        self.f_save_data_len(self.m_data_len_path)

    if flag_cal_mean_std:
        self.f_save_mean_std(self.m_ms_input_path, self.m_ms_output_path)

    # done
    return
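# Illustrative sketch (not the toolkit's nii_stats.f_online_mean_std): a
# Welford/Chan-style online update of per-dimension mean and variance, so
# statistics can be accumulated file by file without loading all data at once.
import numpy as np

def demo_online_mean_var(data, mean, var, count):
    """data: (frames, dim); mean, var: (dim,); count: frames seen so far."""
    n_new = data.shape[0]
    if n_new == 0:
        return mean, var, count
    new_mean = data.mean(axis=0)
    new_var = data.var(axis=0)
    total = count + n_new
    delta = new_mean - mean
    # combine the two partial estimates into one running estimate
    merged_mean = mean + delta * n_new / total
    merged_var = (var * count + new_var * n_new
                  + delta ** 2 * count * n_new / total) / total
    return merged_mean, merged_var, total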
def f_run_one_epoch(args, pt_model, loss_wrapper, \
                    device, monitor, \
                    data_loader, epoch_idx, optimizer = None, \
                    target_norm_method = None):
    """
    f_run_one_epoch:
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over the loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader.
       epoch_idx:    int, index of the current epoch
       optimizer:    torch optimizer or None
                     if None, back propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                           (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    pbar = tqdm(data_loader)
    epoch_num = monitor.get_max_epoch()

    for data_idx, (data_in, data_tar, data_info, idx_orig) in enumerate(pbar):

        pbar.set_description("Epoch: {}/{}".format(epoch_idx, epoch_num))

        # idx_orig is the original idx in the dataset,
        # which can be different from data_idx when shuffle = True
        #idx_orig = idx_orig.numpy()[0]
        #data_seq_info = data_info[0]

        # send data to device
        if optimizer is not None:
            optimizer.zero_grad()

        # compute
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model(data_in)

        # compute loss and do back propagation
        loss_vals = [0]
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside the loss,
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)

        # return the loss from loss_wrapper
        # loss_computed may be [[loss_1, loss_2, ...], [flag_1, flag_2, ...]],
        #  which contains multiple losses and flags indicating whether
        #  the corresponding loss should be taken into consideration
        #  for early stopping,
        # or loss_computed may simply be a tensor loss
        loss_computed = loss_wrapper.compute(data_gen, normed_target)

        # To handle cases where there are multiple loss functions:
        # when loss_computed is [[loss_1, loss_2, ...], [flag_1, flag_2, ...]]
        #   loss: sum of [loss_1, loss_2, ...], for backward()
        #   loss_vals: [loss_1.item(), loss_2.item(), ...], for logging
        #   loss_flags: [True/False, ...], for logging,
        #     whether loss_n is used for early stopping
        # when loss_computed is a single loss
        #   loss: loss
        #   loss_vals: [loss.item()]
        #   loss_flags: [True]
        loss, loss_vals, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # Back-propagation using the summed loss
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    pbar.close()
    return
def f_run_one_epoch(args, pt_model, loss_wrapper, \
                    device, monitor, \
                    data_loader, epoch_idx, optimizer = None):
    """
    f_run_one_epoch:
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model:     pytorch model (torch.nn.Module)
       loss_wrapper: a wrapper over the loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader.
       epoch_idx:    int, index of the current epoch
       optimizer:    torch optimizer or None
                     if None, back propagation will be skipped
                     (for the development set)
    """
    # timer
    start_time = time.time()

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # idx_orig is the original idx in the dataset,
        # which can be different from data_idx when shuffle = True
        #idx_orig = idx_orig.numpy()[0]
        #data_seq_info = data_info[0]

        # send data to device
        if optimizer is not None:
            optimizer.zero_grad()

        # compute
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                data_gen = pt_model(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            # normal case for model.forward(input)
            data_gen = pt_model(data_in)

        # compute loss and do back propagation
        loss_value = 0
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside the loss,
            # thus, do normalization here
            normed_target = pt_model.normalize_target(data_tar)
            loss = loss_wrapper.compute(data_gen, normed_target)
            loss_value = loss.item()
            if optimizer is not None:
                loss.backward()
                optimizer.step()

        # log down process information
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            monitor.log_loss(loss_value / batchsize, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    return
def __init__(self,
             dataset_name, \
             list_file_list, \
             list_input_dirs, input_exts, input_dims, input_reso, \
             input_norm, \
             list_output_dirs, output_exts, output_dims, output_reso, \
             output_norm, \
             stats_path, \
             data_format = nii_dconf.h_dtype_str, \
             params = None, \
             truncate_seq = None, \
             min_seq_len = None,
             save_mean_std = True, \
             wav_samp_rate = None, \
             flag_lang = 'EN', \
             way_to_merge = 'concatenate',
             global_arg = None,
             dset_config = None,
             augment_funcs = None,
             transform_funcs = None):
    """ The signature is similar to default_io.NIIDataSetLoader.
    file_list, input_dirs, and output_dirs are different.
    One additional optional argument is way_to_merge.

    Args
    ----
      dataset_name: a string to name this dataset;
                    this will be used to name the statistics files,
                    such as the mean/std for this dataset
      list_file_list: a list of file_name paths
      list_input_dirs: a list of lists of dirs for input features
      input_exts: a list of input feature name extensions
      input_dims: a list of input feature dimensions
      input_reso: a list of input feature temporal resolutions, or None
      input_norm: a list of bool, whether to normalize each input feature
      list_output_dirs: a list of lists of dirs for output features
      output_exts: a list of output feature name extensions
      output_dims: a list of output feature dimensions
      output_reso: a list of output feature temporal resolutions, or None
      output_norm: a list of bool, whether to normalize each target feature
      stats_path: path to the directory of statistics (mean/std)
      data_format: method to load the data
                   '<f4' (default): load data as float32, little-endian
                   'htk': load data as htk format
      params: parameters for torch.utils.data.DataLoader
      truncate_seq: None or int, truncate data sequences into smaller chunks;
                    truncate_seq > 0 specifies the chunk length
      min_seq_len: None (default) or int, minimum length of an utterance;
                   utterances shorter than min_seq_len will be ignored
      save_mean_std: bool, True (default): save mean and std
      wav_samp_rate: None (default) or int, if the input data has waveforms,
                     please set the sampling rate. It is used by _data_writer
      flag_lang: str, 'EN' (default), if the input data has text, the text
                 will be converted into code indices. flag_lang indicates the
                 language for the text processor. It is used by _data_reader
      way_to_merge: string, 'concatenate' (default) or 'merge'
                    'concatenate': simply concatenate multiple corpora
                    'merge': create mini-batches by merging data from each corpus
      global_arg: argument parser returned by arg_parse.f_args_parsed(),
                  default None
      augment_funcs: None, or list of functions for data augmentation
      transform_funcs: None, or list of functions for data transformation

    Methods
    -------
      get_loader(): return a torch.utils.data.DataLoader
      get_dataset(): return a torch.utils.data.Dataset
    """
    # check whether input_dirs and output_dirs are lists of lists
    if type(list_input_dirs[0]) is list and \
       type(list_output_dirs[0]) is list and \
       type(list_file_list) is list and \
       len(list_input_dirs) == len(list_output_dirs) and \
       len(list_input_dirs) == len(list_file_list):
        pass
    else:
        mes = "NII_MergeDataSetLoader: input_dirs, output_dirs, "
        mes += "and file_list should be lists of lists. "
        mes += "They should have equal length. But we have:"
        mes += "{:s}\n{:s}\n{:s}".format(
            str(list_input_dirs), str(list_output_dirs), str(list_file_list))
        nii_warn.f_die(mes)

    if type(dataset_name) is list:
        if len(dataset_name) != len(list_input_dirs):
            mes = "dataset_name should have {:d} elements. ".format(
                len(list_file_list))
            mes += "But we have: {:s}".format(str(dataset_name))
            nii_warn.f_die(mes)
        elif len(list(set(dataset_name))) != len(list_input_dirs):
            mes = "dataset_name has duplicated elements: {:s}".format(
                str(dataset_name))
            nii_warn.f_die(mes)
        else:
            tmp_dnames = dataset_name
    else:
        tmp_dnames = [dataset_name + '_sub_{:d}'.format(idx) \
                      for idx in np.arange(len(list_input_dirs))]

    # create the individual datasets
    lst_dset = []
    for sub_input_dirs, sub_output_dirs, sub_file_list, tmp_name in \
        zip(list_input_dirs, list_output_dirs, list_file_list, tmp_dnames):
        lst_dset.append(
            nii_default_dset.NIIDataSetLoader(
                tmp_name, sub_file_list,
                sub_input_dirs, input_exts, input_dims, input_reso, \
                input_norm, \
                sub_output_dirs, output_exts, output_dims, output_reso, \
                output_norm, \
                stats_path, data_format, params, truncate_seq, min_seq_len,
                save_mean_std, wav_samp_rate, flag_lang, global_arg))

    # list of the datasets
    self.m_datasets = lst_dset
    self.way_to_merge = way_to_merge

    # create the data loader
    if way_to_merge == 'concatenate':
        # to create the DataLoader, we need the pytorch dataset
        py_datasets = ConcatDataset([x.get_dataset() for x in lst_dset])

        # legacy implementation, no need to use
        ####
        # Although members in lst_dset have their own DataLoader, we need
        # to create a DataLoader for the concatenated dataset
        ###
        if params is None:
            tmp_params = nii_dconf.default_loader_conf
        else:
            tmp_params = params.copy()

        # save parameters
        self.m_params = tmp_params.copy()

        #
        if 'sampler' in tmp_params:
            tmp_sampler = None
            if tmp_params['sampler'] == nii_sampler_fn.g_str_sampler_bsbl:
                if 'batch_size' in tmp_params:
                    # initialize the sampler
                    tmp_sampler = nii_sampler_fn.SamplerBlockShuffleByLen(
                        py_datasets.f_get_seq_len_list(),
                        tmp_params['batch_size'])
                    # turn off automatic shuffle
                    tmp_params['shuffle'] = False
                else:
                    nii_warn.f_die("Sampler requires batch size > 1")
            tmp_params['sampler'] = tmp_sampler

        # collate function
        if 'batch_size' in tmp_params and tmp_params['batch_size'] > 1:
            # use customize_collate to handle data with unequal lengths;
            # we cannot use the default collate_fn
            collate_fn = nii_collate_fn.customize_collate
        else:
            collate_fn = None

        # use the default DataLoader
        self.m_loader = torch.utils.data.DataLoader(
            py_datasets, collate_fn=collate_fn, **tmp_params)
    else:
        # sample mini-batches of equal size from each sub dataset
        # use a specific dataloader
        self.m_loader = merge_loader(lst_dset)
        self.m_params = lst_dset[0].get_loader_params()
    return
def __init__(self,
             dataset_name, \
             file_list, \
             input_dirs, input_exts, input_dims, input_reso, \
             input_norm, \
             output_dirs, output_exts, output_dims, output_reso, \
             output_norm, \
             stats_path, \
             data_format = '<f4', \
             truncate_seq = None, \
             min_seq_len = None, \
             save_mean_std = True, \
             wav_samp_rate = None):
    """
    Args:
      dataset_name: name of this data set
      file_list: a list of file name strings (without extension)
      input_dirs: a list of dirs from which each input feature is loaded
      input_exts: a list of input feature name extensions
      input_dims: a list of input feature dimensions
      input_reso: a list of input feature temporal resolutions
      output_dirs: a list of dirs from which each output feature is loaded
      output_exts: a list of output feature name extensions
      output_dims: a list of output feature dimensions
      output_reso: a list of output feature temporal resolutions
      stats_path: path to the directory that saves the mean/std and
                  utterance lengths
      data_format: method to load the data
                   '<f4' (default): load data as float32, little-endian
                   'htk': load data as htk format
      truncate_seq: None or int, truncate sequences into chunks;
                    truncate_seq > 0 specifies the chunk length
    """
    # initialization
    self.m_set_name = dataset_name
    self.m_file_list = file_list
    self.m_input_dirs = input_dirs
    self.m_input_exts = input_exts
    self.m_input_dims = input_dims

    self.m_output_dirs = output_dirs
    self.m_output_exts = output_exts
    self.m_output_dims = output_dims

    if len(self.m_input_dirs) != len(self.m_input_exts) or \
       len(self.m_input_dirs) != len(self.m_input_dims):
        nii_warn.f_print("Input dirs, exts, dims, unequal length", 'error')
        nii_warn.f_print(str(self.m_input_dirs), 'error')
        nii_warn.f_print(str(self.m_input_exts), 'error')
        nii_warn.f_print(str(self.m_input_dims), 'error')
        nii_warn.f_die("Please check input dirs, exts, dims")

    if len(self.m_output_dims) != len(self.m_output_exts) or \
       (self.m_output_dirs and \
        len(self.m_output_dirs) != len(self.m_output_exts)):
        nii_warn.f_print("Output dirs, exts, dims, unequal length", 'error')
        nii_warn.f_die("Please check output dirs, exts, dims")

    # fill in m_*_reso and m_*_norm
    def _tmp_f(list2, default_value, length):
        if list2 is None:
            return [default_value for x in range(length)]
        else:
            return list2

    self.m_input_reso = _tmp_f(input_reso, 1, len(input_dims))
    self.m_input_norm = _tmp_f(input_norm, True, len(input_dims))
    self.m_output_reso = _tmp_f(output_reso, 1, len(output_dims))
    self.m_output_norm = _tmp_f(output_norm, True, len(output_dims))
    if len(self.m_input_reso) != len(self.m_input_dims):
        nii_warn.f_die("Please check input_reso")
    if len(self.m_output_reso) != len(self.m_output_dims):
        nii_warn.f_die("Please check output_reso")
    if len(self.m_input_norm) != len(self.m_input_dims):
        nii_warn.f_die("Please check input_norm")
    if len(self.m_output_norm) != len(self.m_output_dims):
        nii_warn.f_die("Please check output_norm")

    # dimensions
    self.m_input_all_dim = sum(self.m_input_dims)
    self.m_output_all_dim = sum(self.m_output_dims)
    self.m_io_dim = self.m_input_all_dim + self.m_output_all_dim

    self.m_truncate_seq = truncate_seq
    self.m_min_seq_len = min_seq_len
    self.m_save_ms = save_mean_std

    # in case there is waveform data in the input or output features
    self.m_wav_sr = wav_samp_rate

    # sanity check on the resolution configuration
    # currently, only input features can have different reso,
    # and m_input_reso must be the same for all input features
    if any([x != self.m_input_reso[0] for x in self.m_input_reso]):
        nii_warn.f_print("input_reso: %s" % (str(self.m_input_reso)), 'error')
        nii_warn.f_print("NIIDataSet does not support", 'error', end='')
        nii_warn.f_die(" different input_reso")
    if any([x != 1 for x in self.m_output_reso]):
        nii_warn.f_print("NIIDataSet only supports", 'error', end='')
        nii_warn.f_die(" output_reso = [1, 1, ... 1]")
    self.m_single_reso = self.m_input_reso[0]

    # To make sure that the target waveform length is exactly equal
    # to the up-sampled sequence length,
    # self.m_truncate_seq must be changed to N * up_sample
    if self.m_truncate_seq is not None:
        # assume the input resolution is the same
        self.m_truncate_seq = self.f_adjust_len(self.m_truncate_seq)

    # method to load/write raw data
    if data_format == '<f4':
        self.f_load_data = _data_reader
        self.f_length_data = _data_len_reader
        self.f_write_data = lambda x, y: _data_writer(x, y, self.m_wav_sr)
    else:
        nii_warn.f_print("Unsupported dtype %s" % (data_format))
        nii_warn.f_die("Only supports np.float32 <f4")

    # check the validity of the data
    self.f_check_file_list()

    # log down statistics
    #  1. length of each data utterance
    #  2. mean / std of each feature
    def get_name(stats_path, set_name, file_name):
        tmp = set_name + '_' + file_name
        return os.path.join(stats_path, tmp)

    self.m_ms_input_path = get_name(stats_path, self.m_set_name, \
                                    nii_dconf.mean_std_i_file)
    self.m_ms_output_path = get_name(stats_path, self.m_set_name, \
                                     nii_dconf.mean_std_o_file)
    self.m_data_len_path = get_name(stats_path, self.m_set_name, \
                                    nii_dconf.data_len_file)

    # initialize data length and mean/std
    flag_cal_len = self.f_init_data_len_stats(self.m_data_len_path)
    flag_cal_mean_std = self.f_init_mean_std(self.m_ms_input_path,
                                             self.m_ms_output_path)

    # if data information is not available, read it again from the data
    if flag_cal_len or flag_cal_mean_std:
        self.f_calculate_stats(flag_cal_len, flag_cal_mean_std)

    # check
    if self.__len__() < 1:
        nii_warn.f_print("Failed to load any data", "error")
        nii_warn.f_die("Please check configuration")
    # done
    return
def f_run_one_epoch_WGAN(
        args, pt_model_G, pt_model_D, loss_wrapper, \
        device, monitor, \
        data_loader, epoch_idx, optimizer_G = None, optimizer_D = None, \
        target_norm_method = None):
    """
    f_run_one_epoch_WGAN:
       similar to f_run_one_epoch_GAN, but for WGAN
    """
    # timer
    start_time = time.time()

    # number of critic updates (default 5)
    if hasattr(args, "wgan_critic_num"):
        num_critic = args.wgan_critic_num
    else:
        num_critic = 5
    # clip value
    if hasattr(args, "wgan_clamp"):
        wgan_clamp = args.wgan_clamp
    else:
        wgan_clamp = 0.01

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # send data to device
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # prepare data
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside the loss,
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # to device (we assume noise will be generated by the model itself)
        # here we only provide the external condition
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update Discriminator
        ############################
        # train with real data
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()
        d_out_real_mean = d_out_real.mean()

        # train with fake data
        #  generate samples
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model_G(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model_G(data_in)

        # data_gen.detach() is required
        #  https://github.com/pytorch/examples/issues/116
        d_out_fake = pt_model_D(data_gen.detach())
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()
        d_out_fake_mean = d_out_fake.mean()

        errD = errD_real + errD_fake

        if optimizer_D is not None:
            optimizer_D.step()

        # clip weights of the discriminator
        for p in pt_model_D.parameters():
            p.data.clamp_(-wgan_clamp, wgan_clamp)

        ############################
        # Update Generator
        ############################
        pt_model_G.zero_grad()
        d_out_fake_for_G = pt_model_D(data_gen)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)
        errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        errG = errG_gan + errG_aux

        # only update the generator after num_critic updates
        # of the discriminator
        if data_idx % num_critic == 0 and optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        d_out_fake_for_G_mean = d_out_fake_for_G.mean()

        # construct the loss for logging and early stopping
        # only use errG_aux for early stopping
        loss_computed = [
            [errG_aux, errG_gan, errD_real, errD_fake,
             d_out_real_mean, d_out_fake_mean, d_out_fake_for_G_mean],
            [True, False, False, False, False, False, False]]

        # to handle cases where there are multiple loss functions
        loss, loss_vals, loss_flags = nii_nn_tools.f_process_loss(
            loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    return
def f_train_wrapper_GAN(
        args, pt_model_G, pt_model_D, loss_wrapper, device, \
        optimizer_G_wrapper, optimizer_D_wrapper, \
        train_dataset_wrapper, \
        val_dataset_wrapper = None, \
        checkpoint_G = None, checkpoint_D = None):
    """
    f_train_wrapper_GAN(
     args, pt_model_G, pt_model_D, loss_wrapper, device,
     optimizer_G_wrapper, optimizer_D_wrapper,
     train_dataset_wrapper, val_dataset_wrapper = None,
     checkpoint_G = None, checkpoint_D = None):

      A wrapper to run the training process

    Args:
       args: argument information given by argparse
       pt_model_G: generator, pytorch model (torch.nn.Module)
       pt_model_D: discriminator, pytorch model (torch.nn.Module)

       loss_wrapper: a wrapper over the loss functions
                     loss_wrapper.compute_gan_D_real(discriminator_output)
                     loss_wrapper.compute_gan_D_fake(discriminator_output)
                     loss_wrapper.compute_gan_G(discriminator_output)
                     loss_wrapper.compute_aux(fake, real)
       device: torch.device("cuda") or torch.device("cpu")

       optimizer_G_wrapper:
           an optimizer wrapper for the generator (defined in op_manager.py)
       optimizer_D_wrapper:
           an optimizer wrapper for the discriminator (defined in op_manager.py)

       train_dataset_wrapper:
           a wrapper over the training data set (data_io/default_data_io.py)
           train_dataset_wrapper.get_loader() returns torch.DataSetLoader

       val_dataset_wrapper:
           a wrapper over the validation data set (data_io/default_data_io.py)
           it can be None

       checkpoint_G: a check_point that stores everything to resume training
       checkpoint_D: a check_point that stores everything to resume training
    """
    nii_display.f_print_w_date("Start model training")

    # get the optimizers
    optimizer_G_wrapper.print_info()
    optimizer_D_wrapper.print_info()
    optimizer_G = optimizer_G_wrapper.optimizer
    optimizer_D = optimizer_D_wrapper.optimizer
    epoch_num = optimizer_G_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_G_wrapper.get_no_best_epoch_num()

    # get the data loader for the training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # get the training process monitor
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # if validation data is provided, get the data loader for the val set
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None

    # training log information
    train_log = ''
    model_tags = ["_G", "_D"]

    # prepare for DataParallelism if available
    # pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        nii_display.f_die("data_parallel not implemented for GAN")
    else:
        nii_display.f_print("Use single GPU: %s" % \
                            (torch.cuda.get_device_name(device)))
        flag_multi_device = False
        normtarget_f = None
    pt_model_G.to(device, dtype=nii_dconf.d_dtype)
    pt_model_D.to(device, dtype=nii_dconf.d_dtype)

    # print the networks
    nii_display.f_print("Setup generator")
    f_model_show(pt_model_G)
    nii_display.f_print("Setup discriminator")
    f_model_show(pt_model_D)

    # resume training or initialize the models if necessary
    cp_names = CheckPointKey()
    if checkpoint_G is not None or checkpoint_D is not None:
        for checkpoint, optimizer, pt_model, model_name in \
            zip([checkpoint_G, checkpoint_D], [optimizer_G, optimizer_D],
                [pt_model_G, pt_model_D], ["Generator", "Discriminator"]):

            nii_display.f_print("For %s" % (model_name))
            if type(checkpoint) is dict:
                # checkpoint
                # load model parameters and optimizer state
                if cp_names.state_dict in checkpoint:
                    # wrap the state_dict in f_state_dict_wrapper
                    # in case the model was saved with DataParallel on
                    pt_model.load_state_dict(
                        nii_nn_tools.f_state_dict_wrapper(
                            checkpoint[cp_names.state_dict],
                            flag_multi_device))

                # load optimizer state
                if cp_names.optimizer in checkpoint:
                    optimizer.load_state_dict(checkpoint[cp_names.optimizer])

                # optionally, load the training history
                if not args.ignore_training_history_in_trained_model:
                    #nii_display.f_print("Load ")
                    if cp_names.trnlog in checkpoint:
                        monitor_trn.load_state_dic(
                            checkpoint[cp_names.trnlog])
                    if cp_names.vallog in checkpoint and monitor_val:
                        monitor_val.load_state_dic(
                            checkpoint[cp_names.vallog])
                    if cp_names.info in checkpoint:
                        train_log = checkpoint[cp_names.info]
                    nii_display.f_print("Load check point, resume training")
                else:
                    nii_display.f_print("Load pretrained model and optimizer")
            elif checkpoint is not None:
                # only the model status
                #pt_model.load_state_dict(checkpoint)
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(
                        checkpoint, flag_multi_device))
                nii_display.f_print("Load pretrained model")
            else:
                nii_display.f_print("No pretrained model")
        # done with resuming training

    # other variables
    flag_early_stopped = False
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    if hasattr(loss_wrapper, "flag_wgan") and loss_wrapper.flag_wgan:
        f_wrapper_gan_one_epoch = f_run_one_epoch_WGAN
    else:
        f_wrapper_gan_one_epoch = f_run_one_epoch_GAN

    # print
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')

    # loop over multiple epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training for one epoch
        pt_model_D.train()
        pt_model_G.train()
        f_wrapper_gan_one_epoch(
            args, pt_model_G, pt_model_D, loss_wrapper, device, \
            monitor_trn, train_data_loader, \
            epoch_idx, optimizer_G, optimizer_D, normtarget_f)

        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)

        # if necessary, do validation
        if val_dataset_wrapper is not None:
            # set eval() if necessary
            if args.eval_mode_for_validation:
                pt_model_G.eval()
                pt_model_D.eval()
            with torch.no_grad():
                f_wrapper_gan_one_epoch(
                    args, pt_model_G, pt_model_D, loss_wrapper, \
                    device, \
                    monitor_val, val_data_loader, \
                    epoch_idx, None, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
        else:
            time_val, loss_val = 0, 0

        if val_dataset_wrapper is not None:
            flag_new_best = monitor_val.is_new_best()
        else:
            flag_new_best = True

        # print information
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val,
            flag_new_best)

        # save the best models
        if flag_new_best:
            for pt_model, model_tag in \
                zip([pt_model_G, pt_model_D], model_tags):
                tmp_best_name = f_save_trained_name_GAN(args, model_tag)
                torch.save(pt_model.state_dict(), tmp_best_name)

        # save an intermediate model if necessary
        if not args.not_save_each_epoch:
            # save the discriminator and generator
            for pt_model, optimizer, model_tag in \
                zip([pt_model_G, pt_model_D], [optimizer_G, optimizer_D],
                    model_tags):

                tmp_model_name = f_save_epoch_name_GAN(
                    args, epoch_idx, model_tag)
                if monitor_val is not None:
                    tmp_val_log = monitor_val.get_state_dic()
                else:
                    tmp_val_log = None
                # save
                tmp_dic = {
                    cp_names.state_dict: pt_model.state_dict(),
                    cp_names.info: train_log,
                    cp_names.optimizer: optimizer.state_dict(),
                    cp_names.trnlog: monitor_trn.get_state_dic(),
                    cp_names.vallog: tmp_val_log
                }
                torch.save(tmp_dic, tmp_model_name)
                if args.verbose == 1:
                    nii_display.f_eprint(str(datetime.datetime.now()))
                    nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                         flush=True)

        # early stopping
        if monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break

    # loop done
    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end='')
    for model_tag in model_tags:
        nii_display.f_print("{}".format(
            f_save_trained_name_GAN(args, model_tag)))
    return
def f_run_one_epoch_GAN(
        args, pt_model_G, pt_model_D, loss_wrapper, \
        device, monitor, \
        data_loader, epoch_idx, optimizer_G = None, optimizer_D = None, \
        target_norm_method = None):
    """
    f_run_one_epoch_GAN:
       run one epoch over the dataset (for training or validation sets)

    Args:
       args:         from argparse
       pt_model_G:   pytorch model (torch.nn.Module), generator
       pt_model_D:   pytorch model (torch.nn.Module), discriminator
       loss_wrapper: a wrapper over the loss function
                     loss_wrapper.compute(generated, target)
       device:       torch.device("cuda") or torch.device("cpu")
       monitor:      defined in op_process_monitor.py
       data_loader:  pytorch DataLoader.
       epoch_idx:    int, index of the current epoch
       optimizer_G:  torch optimizer or None, for the generator
       optimizer_D:  torch optimizer or None, for the discriminator
                     if None, back propagation will be skipped
                     (for the development set)
       target_norm_method: method to normalize target data
                           (by default, use pt_model.normalize_target)
    """
    # timer
    start_time = time.time()

    # loop over samples
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # send data to device
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # prepare data
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # there is no way to normalize the data inside the loss,
            # thus, do normalization here
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # to device (we assume noise will be generated by the model itself)
        # here we only provide the external condition
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update Discriminator
        ############################
        # train with real data
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()

        # this should be given by pt_model_D or the loss wrapper
        #d_out_real_mean = d_out_real.mean()

        # train with fake data
        #  generate samples
        if args.model_forward_with_target:
            # if model.forward requires (input, target) as arguments
            # for example, for auto-encoder & autoregressive model
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # special case when model.forward requires data_info
                data_gen = pt_model_G(data_in, data_info)
            else:
                # normal case for model.forward(input)
                data_gen = pt_model_G(data_in)

        # data_gen.detach() is required
        #  https://github.com/pytorch/examples/issues/116
        d_out_fake = pt_model_D(data_gen.detach())
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()

        errD = errD_real + errD_fake

        if optimizer_D is not None:
            optimizer_D.step()

        ############################
        # Update Generator
        ############################
        pt_model_G.zero_grad()
        d_out_fake_for_G = pt_model_D(data_gen)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)

        # if defined, calculate the auxiliary loss
        if hasattr(loss_wrapper, "compute_aux"):
            errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        else:
            errG_aux = torch.zeros_like(errG_gan)

        # if defined, calculate the feature-matching loss
        if hasattr(loss_wrapper, "compute_feat_match"):
            errG_feat = loss_wrapper.compute_feat_match(
                d_out_real, d_out_fake_for_G)
        else:
            errG_feat = torch.zeros_like(errG_gan)

        # sum the losses for the generator
        errG = errG_gan + errG_aux + errG_feat
        if optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        # construct the loss for logging and early stopping
        # only use errG_aux for early stopping
        loss_computed = [
            [errG_aux, errD_real, errD_fake, errG_gan, errG_feat],
            [True, False, False, False, False]]

        # to handle cases where there are multiple loss functions
        _, loss_vals, loss_flags = nii_nn_tools.f_process_loss(loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is supposed to be the average loss value
            # over samples in the batch, thus, just loss_value
            # rather than loss_value / batchsize
            monitor.log_loss(loss_vals, loss_flags, \
                             (end_time - start_time) / batchsize, \
                             data_seq_info, idx_orig.numpy()[idx], \
                             epoch_idx)
            # print info for one sentence
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx, \
                                              idx_orig.numpy()[idx], \
                                              epoch_idx)
            #
        # start the timer for a new batch
        start_time = time.time()

    # loop done
    return