def __init__(self, model_path, keywords, tmpdir):
    assert model_path.endswith(".pth")
    self.config = torch.load(model_path, map_location='cpu')['config']
    # TODO remove
    # self.config['exp']['save_dir'] = "/mnt/data/pytorch-kaldi/exp_TIMIT_MLP_FBANK"
    self.model = model_init(self.config)
    logger.info(self.model)
    # TODO GPU decoding
    self.max_seq_length_train_curr = -1

    self.out_dir = os.path.join(self.config['exp']['save_dir'], self.config['exp']['name'])

    # setup directory for checkpoint saving
    self.checkpoint_dir = os.path.join(self.out_dir, 'checkpoints')
    # Save configuration file into checkpoint directory:
    ensure_dir(self.checkpoint_dir)
    config_save_path = os.path.join(self.out_dir, 'config.json')
    with open(config_save_path, 'w') as f:
        json.dump(self.config, f, indent=4, sort_keys=False)

    self.epoch, self.global_step, _ = resume_checkpoint(model_path, self.model, logger)

    self.phoneme_dict = self.config['dataset']['dataset_definition']['phoneme_dict']

    graph_dir = make_kaldi_decoding_graph(keywords, tmpdir)
    self.graph_path = os.path.join(graph_dir, "HCLG.fst")
    assert os.path.exists(self.graph_path)
    self.words_path = os.path.join(graph_dir, "words.txt")
    assert os.path.exists(self.words_path)
    self.alignment_model_path = os.path.join(graph_dir, "final.mdl")
    assert os.path.exists(self.alignment_model_path)

def load_warm_start(self, path_to_checkpoint):
    checkpoint = torch.load(path_to_checkpoint, map_location='cpu')
    _state_dict = checkpoint['state_dict']
    _state_dict_new = {k: v for k, v in _state_dict.items() if 'MLP.layers.5' not in k}
    _state_dict_new.update({k: v for k, v in self.state_dict().items() if 'MLP.layers.5' in k})
    self.load_state_dict(_state_dict_new)
    logger.info(f"Warm start with model from {path_to_checkpoint}")

def _reduce_lr(self, epoch):
    for i, param_group in enumerate(self.optimizer.param_groups):
        old_lr = float(param_group['lr'])
        new_lr = max(old_lr * self.factor, self.min_lrs[i])
        if old_lr - new_lr > self.eps:
            param_group['lr'] = new_lr
            if self.verbose:
                logger.info('Epoch {:d}: reducing learning rate'
                            ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

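# Hedged worked example (illustration only, the concrete values are assumptions):
# with factor=0.5 and min_lr=1e-6, a group starting at 1e-3 is stepped
# 1e-3 -> 5e-4 -> 2.5e-4 -> ... and clamped at min_lr, exactly what _reduce_lr above
# does per parameter group.
def _example_lr_reduction_schedule(initial_lr=1e-3, factor=0.5, min_lr=1e-6, steps=5):
    lrs, lr = [], initial_lr
    for _ in range(steps):
        lr = max(lr * factor, min_lr)
        lrs.append(lr)
    return lrs  # [0.0005, 0.00025, 0.000125, 6.25e-05, 3.125e-05]
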
def get_seq_len(self, epoch=None):
    if epoch is None:
        epoch = self.last_epoch
    max_seq_length_train_curr = self.start_seq_len_train
    if self.increase_seq_length_train:
        max_seq_length_train_curr = self.start_seq_len_train * (
            self.multply_factor_seq_len_train ** epoch)
        if max_seq_length_train_curr > self.max_seq_length_train:
            max_seq_length_train_curr = self.max_seq_length_train
    if self.verbose:
        logger.info(f"max_seq_length_train_curr set to {max_seq_length_train_curr}")
    return max_seq_length_train_curr

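# Hedged illustration of the sequence-length curriculum implemented by get_seq_len
# above: the cap grows geometrically per epoch until it reaches max_seq_length_train.
# The concrete numbers are assumptions, not values from any config in this repo.
def _example_seq_len_curriculum(start=100, multiplier=2, maximum=1000, epochs=6):
    # e.g. epochs 0..5 -> [100, 200, 400, 800, 1000, 1000]
    return [min(start * multiplier ** epoch, maximum) for epoch in range(epochs)]
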
def setup_run(config, optim_overwrite):
    set_seed(config['exp']['seed'])
    torch.backends.cudnn.deterministic = True  # Otherwise I got nans for the CTC gradient

    # TODO remove the data meta info part and move into kaldi folder e.g.
    # dataset_definition = get_dataset_definition(config['dataset']['name'],
    #                                             config['dataset']['data_use']['train_with'])
    # config['dataset']['dataset_definition'] = dataset_definition
    #
    # if 'lab_phn' in config['dataset']['labels_use']:
    #     phoneme_dict = make_phn_dict(config, dataset_definition, 'lab_phn')
    # elif 'lab_phnframe' in config['dataset']['labels_use']:
    #     phoneme_dict = make_phn_dict(config, dataset_definition, 'lab_phnframe')
    # else:
    #     # framewise
    #     phoneme_dict = get_phoneme_dict(config['dataset']['dataset_definition']['phn_mapping_file'],
    #                                     stress_marks=True, word_position_dependency=True)
    #
    # del config['dataset']['dataset_definition']['phn_mapping_file']
    # config['dataset']['dataset_definition']['phoneme_dict'] = phoneme_dict

    model = model_init(config)
    optimizers, lr_schedulers = optimizer_init(config, model, optim_overwrite)
    seq_len_scheduler = seq_len_scheduler_init(config)

    logger.info("".join(["="] * 80))
    logger.info("Architecture:")
    logger.info(model)
    logger.info("".join(["="] * 80))

    metrics = metrics_init(config, model)
    loss = loss_init(config, model)
    return model, loss, metrics, optimizers, config, lr_schedulers, seq_len_scheduler

def _filter_samples_by_length(file_names, feature_dict, features_loaded, label_dict,
                              all_labels_loaded, max_sample_len, min_sample_len):
    samples = {}
    for file in file_names:
        _continue = False
        for feature_name in feature_dict:
            if file not in features_loaded[feature_name]:
                logger.info("Skipping {}, not in features".format(file))
                _continue = True
                break
        for label_name in label_dict:
            if file not in all_labels_loaded[label_name]:
                logger.info("Skipping {}, not in labels".format(file))
                _continue = True
                break
        if _continue:
            # The file is missing from at least one feature/label archive; skip it
            # before the length checks below would raise a KeyError.
            continue
        for feature_name in feature_dict:
            if type(max_sample_len) == int and \
                    len(features_loaded[feature_name][file]) > max_sample_len:
                logger.info(
                    "Skipping {}, feature of size {} too big (max {} expected)".format(
                        file, len(features_loaded[feature_name][file]), max_sample_len))
                _continue = True
                break
            if type(min_sample_len) == int and \
                    min_sample_len > len(features_loaded[feature_name][file]):
                logger.info(
                    f"Skipping {file}, feature of size "
                    + f"{len(features_loaded[feature_name][file])} too small (min {min_sample_len} expected)")
                _continue = True
                break
        if _continue:
            continue

        samples[file] = {"features": {}, "labels": {}}
        for feature_name in feature_dict:
            samples[file]["features"][feature_name] = features_loaded[feature_name][file]
        for label_name in label_dict:
            samples[file]["labels"][label_name] = all_labels_loaded[label_name][file]
    return samples

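# Hedged illustration of the structure returned by _filter_samples_by_length above
# (utterance, feature, and label names are hypothetical):
#
# {
#     "utt_0001": {
#         "features": {"fbank": <ndarray of shape (T, 40)>},
#         "labels":   {"lab_phn": <ndarray of shape (T,)>},
#     },
#     ...
# }
#
# Utterances missing from any feature/label archive, or whose feature length falls
# outside [min_sample_len, max_sample_len], are dropped.
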
def summary(self):
    """Model summary."""
    if self.batch_ordering == "NCL":
        flops, params = profile(
            self,
            input={'fbank': torch.zeros((8, 40, sum(self.context) + 50))},
            custom_ops=self.get_custom_ops_for_counting())
    elif self.batch_ordering == "TNCL":
        flops, params = profile(
            self,
            input={'fbank': torch.zeros((1, 8, 40, 1))})
    else:
        raise NotImplementedError

    logger.info(
        'Trainable parameters: {}'.format(self.trainable_parameters())
        + f'\nFLOPS: ~{millify(flops)} ({flops})\n')
    logger.info(self)

def _convert_from_kaldi_format(self, feature_dict, label_dict):
    logger.info("Converting features from kaldi features!")
    main_feat = next(iter(feature_dict))

    try:
        os.makedirs(self.state.dataset_path)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise

    all_labels_loaded = self._load_labels(label_dict)

    with open(feature_dict[main_feat]["feature_lst_path"], "r") as f:
        lines = f.readlines()
    feat_list = lines
    _sample_index = self.make_feat_chunks(
        feat_list, feature_dict, label_dict, all_labels_loaded, main_feat,
        write_info=not self.state.sorted_by_lengh)

    if self.state.sorted_by_lengh:
        logger.info('Re-extracting kaldi features, this time sorted by length!')
        # for chunk_idx, file_name, start_idx, end_idx in
        _sample_index_dict = sorted(_sample_index, key=lambda x: x[3] - x[2])
        _files_dict = dict([s.split(" ") for s in feat_list])
        sorted_feat_list = []
        for chunk_idx, file_name, start_idx, end_idx in _sample_index_dict:
            sorted_feat_list.append(f"{file_name} {_files_dict[file_name]}")

        self.make_feat_chunks(sorted_feat_list, feature_dict, label_dict,
                              all_labels_loaded, main_feat)

    logger.info('Done extracting kaldi features!')

def resume_checkpoint(resume_path, model, logger, optimizers=None, lr_schedulers=None,
                      seq_len_scheduler=None):
    if not resume_path.endswith(".pth"):
        resume_path = folder_to_checkpoint(resume_path)

    logger.info(f"Loading checkpoint: {resume_path}")
    checkpoint = torch.load(resume_path, map_location='cpu')
    if 'dataset_sampler_state' not in checkpoint:
        checkpoint['dataset_sampler_state'] = None

    if checkpoint['dataset_sampler_state'] is None:
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_step']

    init_model_state_dict = model.state_dict()
    for k in list(checkpoint['state_dict'].keys()):
        if k not in init_model_state_dict:
            logger.info(f"Removed key {k} from loaded state dict")
            del checkpoint['state_dict'][k]
    model.load_state_dict(checkpoint['state_dict'])

    assert (optimizers is None and lr_schedulers is None) \
           or (optimizers is not None and lr_schedulers is not None)
    if optimizers is not None and lr_schedulers is not None:
        for opti_name in checkpoint['optimizers']:
            optimizers[opti_name].load_state_dict(checkpoint['optimizers'][opti_name])
        for lr_sched_name in checkpoint['lr_schedulers']:
            lr_schedulers[lr_sched_name].load_state_dict(checkpoint['lr_schedulers'][lr_sched_name])

    logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, start_epoch))
    # TODO check checkpoint['dataset_sampler_state'] is none
    return start_epoch, global_step, checkpoint['dataset_sampler_state']

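# Hedged usage sketch for resume_checkpoint above (the path and variable names are
# assumptions for illustration; optimizers and lr_schedulers must be passed together):
#
# start_epoch, global_step, sampler_state = resume_checkpoint(
#     "exp/my_run/checkpoints/checkpoint_e10.pth", model, logger,
#     optimizers=optimizers, lr_schedulers=lr_schedulers)
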
def main(config_path, load_path, restart, overfit_small_batch, warm_start, optim_overwrite):
    config = read_json(config_path)
    check_config(config)

    if optim_overwrite:
        optim_overwrite = read_json('cfg/optim_overwrite.json')

    if load_path is not None:
        raise NotImplementedError
    # if resume_path:
    #     # TODO
    #     resume_config = torch.load(folder_to_checkpoint(args.resume), map_location='cpu')['config']
    #     # also the results won't be the same given the different random seeds
    #     # with a different number of draws
    #     del config['exp']['name']
    #     recursive_update(resume_config, config)
    #
    #     print("".join(["="] * 80))
    #     print("Resume with these changes in the config:")
    #     print("".join(["-"] * 80))
    #     print(jsondiff.diff(config, resume_config, dump=True, dumper=jsondiff.JsonDumper(indent=1)))
    #     print("".join(["="] * 80))
    #
    #     config = resume_config
    #
    #     start_time = datetime.datetime.now().strftime('_%Y%m%d_%H%M%S')
    #
    #     config['exp']['name'] = config['exp']['name'] + "r-" + start_time
    # else:

    save_time = datetime.datetime.now().strftime('_%Y%m%d_%H%M%S')
    # config['exp']['name'] = config['exp']['name'] + start_time

    set_seed(config['exp']['seed'])

    config['exp']['save_dir'] = os.path.abspath(config['exp']['save_dir'])

    # Output folder creation
    out_folder = os.path.join(config['exp']['save_dir'], config['exp']['name'])
    if os.path.exists(out_folder):
        print(f"Experiment under {out_folder} exists, copying it to a backup")
        if os.path.exists(os.path.join(out_folder, "checkpoints")) \
                and len(os.listdir(os.path.join(out_folder, "checkpoints"))) > 0:
            shutil.copytree(
                out_folder,
                os.path.join(config['exp']['save_dir'] + "_finished_runs_backup/",
                             config['exp']['name'] + save_time))
            # print(os.listdir(os.path.join(out_folder, "checkpoints")))
            # resume_path = out_folder
        # else:
        if restart:
            shutil.rmtree(out_folder)
            os.makedirs(out_folder + '/exp_files')
    else:
        os.makedirs(out_folder + '/exp_files')

    logger.configure_logger(out_folder)
    check_environment()

    if nvidia_smi_enabled:
        # TODO change criteria or the whole thing
        git_commit = code_versioning()
        if 'versioning' not in config:
            config['versioning'] = {}
        config['versioning']['git_commit'] = git_commit

    logger.info("Experiment name : {}".format(out_folder))
    logger.info("tensorboard : tensorboard --logdir {}".format(os.path.abspath(out_folder)))

    model, loss, metrics, optimizers, config, lr_schedulers, seq_len_scheduler = setup_run(
        config, optim_overwrite)

    if warm_start is not None:
        load_warm_start_op = getattr(model, "load_warm_start", None)
        assert callable(load_warm_start_op)
        model.load_warm_start(warm_start)

    # TODO instead of resuming and making a new folder, make a backup and continue in the same folder
    trainer = Trainer(model, loss, metrics, optimizers, lr_schedulers, seq_len_scheduler,
                      load_path, config,
                      restart_optim=bool(optim_overwrite),
                      do_validation=True,
                      overfit_small_batch=overfit_small_batch)

    trainer.train()

def evaluate(model, metrics, device, out_folder, exp_name, max_label_length, epoch,
             dataset_type, data_cache_root, test_with, all_feats_dict, features_use,
             all_labs_dict, labels_use, phoneme_dict, decoding_info, lab_graph_dir=None,
             tensorboard_logger=None):
    model.eval()
    batch_size = 1
    max_seq_length = -1

    accumulated_test_metrics = {metric: 0 for metric in metrics}

    test_data = test_with
    dataset = get_dataset(
        dataset_type,
        data_cache_root,
        f"{test_data}_{exp_name}",
        {feat: all_feats_dict[feat] for feat in features_use},
        {lab: all_labs_dict[lab] for lab in labels_use},
        max_seq_length,
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=phoneme_dict,
        max_seq_len=max_seq_length,
        max_label_length=max_label_length)

    dataloader = KaldiDataLoader(dataset,
                                 batch_size,
                                 use_gpu=False,
                                 batch_ordering=model.batch_ordering)

    assert len(dataset) >= batch_size, \
        f"Length of test dataset {len(dataset)} too small " \
        + f"for batch_size of {batch_size}"

    n_steps_this_epoch = 0
    warned_size = False

    with Pool(os.cpu_count()) as pool:
        multip_process = Manager()
        metrics_q = multip_process.Queue(maxsize=os.cpu_count())
        # accumulated_test_metrics_future_list = pool.apply_async(metrics_accumulator, (metrics_q, metrics))
        accumulated_test_metrics_future_list = [
            pool.apply_async(metrics_accumulator, (metrics_q, metrics))
            for _ in range(os.cpu_count())
        ]

        with KaldiOutputWriter(out_folder, test_data, model.out_names, epoch) as writer:
            with tqdm(disable=not logger.isEnabledFor(logging.INFO),
                      total=len(dataloader),
                      position=0) as pbar:
                pbar.set_description('E e:{} '.format(epoch))
                for batch_idx, (sample_names, inputs, targets) in enumerate(dataloader):
                    n_steps_this_epoch += 1

                    inputs = to_device(device, inputs)
                    if "lab_phn" not in targets:
                        targets = to_device(device, targets)

                    output = model(inputs)

                    output = detach_cpu(output)
                    targets = detach_cpu(targets)

                    #### Logging ####
                    metrics_q.put((output, targets))

                    pbar.set_description('E e:{} '.format(epoch))
                    pbar.update()
                    #### /Logging ####

                    warned_label = False
                    for output_label in output:
                        if output_label in model.out_names:
                            # squeeze that batch
                            output[output_label] = output[output_label].squeeze(1)
                            # remove blank/padding 0th dim
                            # if config["arch"]["framewise_labels"] == "shuffled_frames":
                            out_save = output[output_label].data.cpu().numpy()
                            # else:
                            #     raise NotImplementedError("TODO make sure the right dimension is taken")
                            #     out_save = output[output_label][:, :-1].data.cpu().numpy()

                            if len(out_save.shape) == 3 and out_save.shape[0] == 1:
                                out_save = out_save.squeeze(0)

                            if dataset.state.dataset_type != DatasetType.SEQUENTIAL_APPENDED_CONTEXT \
                                    and dataset.state.dataset_type != DatasetType.SEQUENTIAL:
                                raise NotImplementedError("TODO rescaling with prior")

                            # if config['dataset']['dataset_definition']['decoding']['normalize_posteriors']:
                            #     # read the config file
                            #     counts = config['dataset']['dataset_definition'] \
                            #         ['data_info']['labels']['lab_phn']['lab_count']
                            #     if out_save.shape[-1] == len(counts) - 1:
                            #         if not warned_size:
                            #             logger.info(
                            #                 f"Counts length is {len(counts)} but output"
                            #                 + f" has size {out_save.shape[-1]}."
                            #                 + f" Assuming that counts is 1 indexed")
                            #             warned_size = True
                            #         counts = counts[1:]
                            #     # Normalize by output count
                            #     # if ctc:
                            #     #     blank_scale = 1.0
                            #     #     # TODO try different blank_scales 4.0 5.0 6.0 7.0
                            #     #     counts[0] /= blank_scale
                            #     #     # for i in range(1, 8):
                            #     #     #     counts[i] /= noise_scale  # TODO try noise_scale for SIL SPN etc I guess
                            #
                            #     prior = np.log(counts / np.sum(counts))
                            #     out_save = out_save - np.log(prior)

                            # shape == NC
                            assert len(out_save.shape) == 2
                            assert len(sample_names) == 1
                            writer.write_mat(output_label, out_save.squeeze(), sample_names[0])
                        else:
                            if not warned_label:
                                logger.debug("Skipping saving forward for decoding for key {}"
                                             .format(output_label))
                                warned_label = True

        for _accumulated_test_metrics in accumulated_test_metrics_future_list:
            metrics_q.put(None)

        for _accumulated_test_metrics in accumulated_test_metrics_future_list:
            _accumulated_test_metrics = _accumulated_test_metrics.get()
            for metric, metric_value in _accumulated_test_metrics.items():
                accumulated_test_metrics[metric] += metric_value

    # test_metrics = {metric: 0 for metric in metrics}
    # for metric in accumulated_test_metrics:
    #     for metric, metric_value in metric.items():
    #         test_metrics[metric] += metric_value

    test_metrics = {
        metric: accumulated_test_metrics[metric] / len(dataloader)
        for metric in accumulated_test_metrics
    }
    if tensorboard_logger is not None:
        tensorboard_logger.set_step(epoch, 'eval')
        for metric, metric_value in test_metrics.items():
            # test_metrics is already averaged over the dataloader above,
            # so log the value directly instead of dividing a second time.
            tensorboard_logger.add_scalar(metric, metric_value)

    # decoding_results = []
    #### DECODING ####
    # for out_lab in model.out_names:
    out_lab = model.out_names[0]  # TODO query from model or sth

    # forward_data_lst = config['data_use']['test_with']  # TODO multiple forward sets
    # forward_data_lst = [config['dataset']['data_use']['test_with']]
    # forward_dec_outs = config['test'][out_lab]['require_decoding']

    # for data in forward_data_lst:
    logger.debug('Decoding {} output {}'.format(test_with, out_lab))

    if out_lab == 'out_cd':
        _label = 'lab_cd'
    elif out_lab == 'out_phn':
        _label = 'lab_phn'
    else:
        raise NotImplementedError(out_lab)

    lab_field = all_labs_dict[_label]

    out_folder = os.path.abspath(out_folder)
    out_dec_folder = '{}/decode_{}_{}'.format(out_folder, test_with, out_lab)

    # logits_test_clean_100_ep006_out_phn.ark
    files_dec_list = glob(f'{out_folder}/exp_files/logits_{test_with}_ep*_{out_lab}.ark')

    if lab_graph_dir is None:
        lab_graph_dir = os.path.abspath(lab_field['lab_graph'])
    if _label == 'lab_phn':
        decode_ctc(data=os.path.abspath(lab_field['lab_data_folder']),
                   graphdir=lab_graph_dir,
                   out_folder=out_dec_folder,
                   featstrings=files_dec_list)
    elif _label == 'lab_cd':
        decode_ce(**decoding_info,
                  alidir=os.path.abspath(lab_field['label_folder']),
                  data=os.path.abspath(lab_field['lab_data_folder']),
                  graphdir=lab_graph_dir,
                  out_folder=out_dec_folder,
                  featstrings=files_dec_list)
    else:
        raise ValueError(_label)

    decoding_results = best_wer(out_dec_folder, decoding_info['scoring_type'])
    logger.info(decoding_results)

    tensorboard_logger.add_text("WER results", str(decoding_results))

    # TODO plotting curves

    return {'test_metrics': test_metrics, "decoding_results": decoding_results}

def _convert_from_kaldi_format(self, feature_dict, label_dict):
    main_feat = next(iter(feature_dict))

    # download files
    try:
        os.makedirs(self.dataset_path)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise

    all_labels_loaded = self._load_labels(label_dict)

    with open(feature_dict[main_feat]["feature_lst_path"], "r") as f:
        lines = f.readlines()
    feat_list = lines
    random.shuffle(feat_list)
    file_chunks = list(split_chunks(feat_list, self.chunk_size))

    self.max_len_per_chunk = [0] * len(file_chunks)
    self.min_len_per_chunk = [sys.maxsize] * len(file_chunks)
    self.samples_per_chunk = []
    for chnk_id, file_chnk in tqdm(list(enumerate(file_chunks))):
        file_names = [feat.split(" ")[0] for feat in file_chnk]

        chnk_prefix = os.path.join(self.dataset_path, f"chunk_{chnk_id:04d}")

        features_loaded = {}
        for feature_name in feature_dict:
            chnk_scp = chnk_prefix + "feats.scp"
            with open(chnk_scp, "w") as f:
                f.writelines(file_chnk)

            features_loaded[feature_name] = load_features(
                chnk_scp, feature_dict[feature_name]["feature_opts"])
            os.remove(chnk_scp)

        samples = {}
        for file in file_names:
            _continue = False
            for feature_name in feature_dict:
                if file not in features_loaded[feature_name]:
                    logger.info(f"Skipping {file}, not in features")
                    _continue = True
                    break
            for label_name in label_dict:
                if file not in all_labels_loaded[label_name]:
                    logger.info(f"Skipping {file}, not in labels")
                    _continue = True
                    break
            if _continue:
                # The file is missing from a feature/label archive; skip it before
                # the length checks below would raise a KeyError.
                continue
            for feature_name in feature_dict:
                if type(self.max_sample_len) == int and \
                        len(features_loaded[feature_name][file]) > self.max_sample_len:
                    logger.info(
                        f"Skipping {file}, feature of size "
                        + f"{len(features_loaded[feature_name][file])} too big "
                        + f"(max {self.max_sample_len} expected)")
                    _continue = True
                    break
                if type(self.min_sample_len) == int and \
                        self.min_sample_len > len(features_loaded[feature_name][file]):
                    logger.info(
                        "Skipping {}, feature of size {} too small (min {} expected)".format(
                            file, len(features_loaded[feature_name][file]), self.min_sample_len))
                    _continue = True
                    break
            if _continue:
                continue

            samples[file] = {"features": {}, "labels": {}}
            for feature_name in feature_dict:
                samples[file]["features"][feature_name] = features_loaded[feature_name][file]
            for label_name in label_dict:
                samples[file]["labels"][label_name] = all_labels_loaded[label_name][file]

        samples_list = list(samples.items())

        mean = {}
        std = {}
        for feature_name in feature_dict:
            feat_concat = []
            for file in file_names:
                feat_concat.append(features_loaded[feature_name][file])

            feat_concat = np.concatenate(feat_concat)
            mean[feature_name] = np.mean(feat_concat, axis=0)
            std[feature_name] = np.std(feat_concat, axis=0)

        if not self.shuffle_frames:
            if self.split_files_max_sample_len:
                sample_splits = splits_by_seqlen(samples_list, self.split_files_max_sample_len,
                                                 self.left_context, self.right_context)
            else:
                sample_splits = [
                    (filename, self.left_context,
                     len(sample_dict["features"][main_feat]) - self.right_context)
                    for filename, sample_dict in samples_list
                ]

            for sample_id, start_idx, end_idx in sample_splits:
                self.max_len_per_chunk[chnk_id] = (end_idx - start_idx) \
                    if (end_idx - start_idx) > self.max_len_per_chunk[chnk_id] else self.max_len_per_chunk[chnk_id]

                self.min_len_per_chunk[chnk_id] = (end_idx - start_idx) \
                    if (end_idx - start_idx) < self.min_len_per_chunk[chnk_id] else self.min_len_per_chunk[chnk_id]

            # sort sigs/labels: longest -> shortest
            sample_splits = sorted(sample_splits, key=lambda x: x[2] - x[1])
            self.samples_per_chunk.append(len(sample_splits))
        else:
            prev_index = 0
            samples_indices = []
            sample_ids = []
            # sample_id, end_idx_total
            for sample_id, data in samples_list:
                prev_index += len(data['features'][main_feat]) \
                    - self.left_context - self.right_context
                sample_ids.append(sample_id)
                samples_indices.append(prev_index)

            sample_splits = (sample_ids, samples_indices)
            self.samples_per_chunk.append(samples_indices[-1])

        assert len(sample_splits) == self.samples_per_chunk[chnk_id]

        torch.save(
            {
                "samples": samples,
                "sample_splits": sample_splits,
                "means": mean,
                "std": std
            }, chnk_prefix + ".pyt")

    # TODO add warning when files get too big -> choose different chunk size

    self._write_info(feature_dict, label_dict)
    logger.info('Done extracting kaldi features!')

def valid_epoch_async_metrics(epoch, model, loss_fun, metrics, config, max_label_length,
                              device, tensorboard_logger):
    """
    Validate after training an epoch

    :return: A log that contains information about validation

    Note:
        The validation metrics in log must have the key 'val_metrics'.
    """
    model.eval()

    valid_loss = 0
    accumulated_valid_metrics = {metric: 0 for metric in metrics}

    valid_data = config['dataset']['data_use']['valid_with']
    _all_feats = config['dataset']['dataset_definition']['datasets'][valid_data]['features']
    _all_labs = config['dataset']['dataset_definition']['datasets'][valid_data]['labels']
    dataset = get_dataset(
        config['training']['dataset_type'],
        config['exp']['data_cache_root'],
        f"{valid_data}_{config['exp']['name']}",
        {feat: _all_feats[feat] for feat in config['dataset']['features_use']},
        {lab: _all_labs[lab] for lab in config['dataset']['labels_use']},
        config['training']['batching']['max_seq_length_valid'],
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=config['dataset']['dataset_definition']['phoneme_dict'],
        max_seq_len=config['training']['batching']['max_seq_length_valid'],
        max_label_length=max_label_length)

    dataloader = KaldiDataLoader(dataset,
                                 config['training']['batching']['batch_size_valid'],
                                 config["exp"]["n_gpu"] > 0,
                                 batch_ordering=model.batch_ordering)

    assert len(dataset) >= config['training']['batching']['batch_size_valid'], \
        f"Length of valid dataset {len(dataset)} too small " \
        + f"for batch_size of {config['training']['batching']['batch_size_valid']}"

    n_steps_this_epoch = 0
    with Pool(os.cpu_count()) as pool:
        multip_process = Manager()
        metrics_q = multip_process.Queue(maxsize=os.cpu_count())
        # accumulated_valid_metrics_future_list = pool.apply_async(metrics_accumulator, (metrics_q, metrics))
        accumulated_valid_metrics_future_list = [
            pool.apply_async(metrics_accumulator, (metrics_q, metrics))
            for _ in range(os.cpu_count())
        ]

        with tqdm(disable=not logger.isEnabledFor(logging.INFO), total=len(dataloader)) as pbar:
            pbar.set_description('V e:{} l: {} '.format(epoch, '-'))
            for batch_idx, (_, inputs, targets) in enumerate(dataloader):
                n_steps_this_epoch += 1

                inputs = to_device(device, inputs)
                if "lab_phn" not in targets:
                    targets = to_device(device, targets)

                output = model(inputs)
                loss = loss_fun(output, targets)

                output = detach_cpu(output)
                targets = detach_cpu(targets)
                loss = detach_cpu(loss)

                #### Logging ####
                valid_loss += loss["loss_final"].item()
                metrics_q.put((output, targets))

                # _valid_metrics = eval_metrics((output, targets), metrics)
                # for metric, metric_value in _valid_metrics.items():
                #     accumulated_valid_metrics[metric] += metric_value

                pbar.set_description('V e:{} l: {:.4f} '.format(epoch, loss["loss_final"].item()))
                pbar.update()
                #### /Logging ####

        for _accumulated_valid_metrics in accumulated_valid_metrics_future_list:
            metrics_q.put(None)
        for _accumulated_valid_metrics in accumulated_valid_metrics_future_list:
            _accumulated_valid_metrics = _accumulated_valid_metrics.get()
            for metric, metric_value in _accumulated_valid_metrics.items():
                accumulated_valid_metrics[metric] += metric_value

    tensorboard_logger.set_step(epoch, 'valid')
    tensorboard_logger.add_scalar('valid_loss', valid_loss / n_steps_this_epoch)
    logger.info(f'valid_loss: {valid_loss / n_steps_this_epoch}')
    for metric in accumulated_valid_metrics:
        tensorboard_logger.add_scalar(metric,
                                      accumulated_valid_metrics[metric] / n_steps_this_epoch)
        logger.info(f'{metric}: {accumulated_valid_metrics[metric] / n_steps_this_epoch}')

    return {
        'valid_loss': valid_loss / n_steps_this_epoch,
        'valid_metrics': {
            metric: accumulated_valid_metrics[metric] / n_steps_this_epoch
            for metric in accumulated_valid_metrics
        }
    }

def save_checkpoint(epoch,
                    global_step,
                    model,
                    optimizers,
                    lr_schedulers,
                    seq_len_scheduler,
                    config,
                    checkpoint_dir,
                    # monitor_best=None,
                    dataset_sampler_state=None,
                    save_best=None):
    """
    Saving checkpoints

    :param epoch: current epoch number
    :param log: logging information of the epoch
    :param save_best: if True, also save the checkpoint as 'checkpoint_best.pth'
    """
    assert dataset_sampler_state != save_best, "save_best is only done at the end of an epoch"

    # TODO figure out why shutil.disk_usage gives a different result than df
    # available_disk_space_in_gb = shutil.disk_usage(checkpoint_dir).free * 1e-9
    available_disk_space_in_gb = run_shell(f"df -h {checkpoint_dir}")
    available_disk_space_in_gb = int(available_disk_space_in_gb.split("\n")[1].split(" ")[13][:-1])
    assert available_disk_space_in_gb > 5, \
        f"available_disk_space_in_gb of {available_disk_space_in_gb} is lower than 5GB. " \
        + f"Aborting the save attempt in order to not corrupt the model files"

    torch_rng_state, python_rng_state, numpy_rng_state = get_rng_state()

    state = {
        'epoch': epoch,
        'global_step': global_step,
        'state_dict': model.state_dict(),
        'optimizers': {
            opti_name: optimizers[opti_name].state_dict()
            for opti_name in optimizers
        },
        'lr_schedulers': {
            lr_sched_name: lr_schedulers[lr_sched_name].state_dict()
            for lr_sched_name in lr_schedulers
        },
        'seq_len_scheduler': seq_len_scheduler,
        'dataset_sampler_state': dataset_sampler_state,
        # 'monitor_best': monitor_best,
        'config': config,
        'torch_rng_state': torch_rng_state,
        'python_rng_state': python_rng_state,
        'numpy_rng_state': numpy_rng_state,
    }

    if dataset_sampler_state is not None:
        # Intermediate save during a training epoch
        all_previous_checkpoints = glob(os.path.join(checkpoint_dir, 'checkpoint_e*_gs*.pth'))

        checkpoint_name = f'checkpoint_e{epoch}_gs{global_step}.pth'
        filename = os.path.join(checkpoint_dir, checkpoint_name)
        torch.save(state, filename)
        logger.info(f"Saved checkpoint: {filename}")

        for old_checkpoint in all_previous_checkpoints:
            if os.path.exists(old_checkpoint):
                os.remove(old_checkpoint)
                logger.info(f"Removed old checkpoint: {old_checkpoint} ")
    else:
        checkpoint_name = f'checkpoint_e{epoch}.pth'
        filename = os.path.join(checkpoint_dir, checkpoint_name)
        torch.save(state, filename)
        logger.info(f"Saved checkpoint: {filename}")

        if epoch >= 3:
            filename_prev = os.path.join(checkpoint_dir, f'checkpoint_e{epoch - 3}.pth')
            if os.path.exists(filename_prev):
                os.remove(filename_prev)
                logger.info(f"Removed old checkpoint: {filename_prev} ")

    if save_best is not None and save_best:
        checkpoint_name = 'checkpoint_best.pth'
        best_path = os.path.join(checkpoint_dir, checkpoint_name)
        torch.save(state, best_path)
        logger.info(f"Saved current best: {checkpoint_name}")

    # available_disk_space_in_gb = shutil.disk_usage(checkpoint_dir).free * 1e-9
    available_disk_space_in_gb = run_shell(f"df -h {checkpoint_dir}")
    available_disk_space_in_gb = int(available_disk_space_in_gb.split("\n")[1].split(" ")[13][:-1])
    assert available_disk_space_in_gb > 5, \
        f"available_disk_space_in_gb of {available_disk_space_in_gb} is lower than 5GB. " \
        + f"Aborting since the next checkpoint save would probably fail because of too little space -> no wasted training compute"

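# Hedged sketch (not wired into save_checkpoint above): reading the free space via an
# explicit `df` output column is less position-dependent than splitting on spaces at a
# fixed index. `--output=avail` and `-BG` assume GNU coreutils; the original TODO about
# shutil.disk_usage disagreeing with df is left untouched.
def _example_available_gb(path):
    import subprocess
    out = subprocess.check_output(["df", "-BG", "--output=avail", path], text=True)
    return int(out.splitlines()[1].strip().rstrip("G"))
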