def test_merge_prepared_data(self):
    # Two one-entry datasets whose symbol tables overlap ("a", "b") but use
    # different ids; merging must produce one shared table and rewrite the
    # serialized ids of both entries against it.
    first_data = PreparedDataList([
        PreparedData(0, 1, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
    ])
    first_symbols = SymbolIdDict({
        0: (0, "a"),
        1: (0, "b"),
        2: (0, "c"),
    })
    second_data = PreparedDataList([
        PreparedData(0, 2, "", "", "", 0, "0,1,2", 0, 0, "", 0, ""),
    ])
    second_symbols = SymbolIdDict({
        0: (0, "b"),
        1: (0, "a"),
        2: (0, "d"),
    })
    prep_list = [
        (first_data, first_symbols),
        (second_data, second_symbols),
    ]

    res, conv = merge_prepared_data(prep_list)

    # Merged symbol table: ids 0 and 1 resolve to "todo", then a, b, c, d.
    self.assertEqual(6, len(conv))
    expected_symbols = ["todo", "todo", "a", "b", "c", "d"]
    for symbol_id, expected_symbol in enumerate(expected_symbols):
        self.assertEqual(expected_symbol, conv.get_symbol(symbol_id))

    # Merged entries: re-indexed, entry ids kept.
    self.assertEqual(2, len(res))
    self.assertEqual(0, res[0].i)
    self.assertEqual(1, res[1].i)
    self.assertEqual(1, res[0].entry_id)
    self.assertEqual(2, res[1].entry_id)

    # Serialized ids now refer to the merged table.
    self.assertEqual("2,3,4", res[0].serialized_symbol_ids)
    self.assertEqual("3,2,5", res[1].serialized_symbol_ids)
def update_symbols_and_text(sentences: SentenceList, sents_new_symbols: List[List[str]]):
    """Rebuild the symbol table from the new symbols and rewrite each sentence.

    Args:
        sentences: sentences to update in place; paired positionally with
            ``sents_new_symbols``.
        sents_new_symbols: one list of symbols per sentence.

    Returns:
        Tuple ``(symbols, sentences)``: the newly created ``SymbolIdDict`` and
        the same ``sentences`` object that was passed in.
    """
    unique_symbols = get_unique_items(list(sents_new_symbols))
    symbols = SymbolIdDict.init_from_symbols(unique_symbols)

    for current, replacement in zip(sentences.items(), sents_new_symbols):
        current.serialized_symbols = symbols.get_serialized_ids(replacement)
        current.text = SymbolIdDict.symbols_to_text(replacement)
        # Symbols and accents must stay aligned one-to-one.
        expected_length = len(replacement)
        assert len(current.get_symbol_ids()) == expected_length
        assert len(current.get_accent_ids()) == expected_length

    return symbols, sentences
def replace_unknown_symbols(self, model_symbols: SymbolIdDict, logger: Logger) -> bool:
    """Replace symbols the model does not know and report whether any existed.

    Args:
        model_symbols: symbol table used to detect unknown symbols and to
            produce the replacement via ``replace_unknown_symbols_with_pad``.
        logger: receives one info message per affected sentence.

    Returns:
        ``True`` if at least one sentence contained unknown symbols.
    """
    found_unknown = False
    for sentence in self.items():
        if not model_symbols.has_unknown_symbols(sentence.symbols):
            continue
        sentence.symbols = model_symbols.replace_unknown_symbols_with_pad(sentence.symbols)
        text = SymbolIdDict.symbols_to_text(sentence.symbols)
        logger.info(f"Sentence {sentence.sent_id} contains unknown symbols: {text}")
        found_unknown = True
        # The replacement must not change the number of symbols.
        assert len(sentence.symbols) == len(sentence.accents)
    return found_unknown
def get_mapped_symbol_weights(model_symbols: SymbolIdDict, trained_weights: Tensor, trained_symbols: SymbolIdDict, custom_mapping: Optional[SymbolsMap], hparams: HParams, logger: Logger) -> Tensor:
    """Create symbol weights for the model and map pretrained weights into them.

    Args:
        model_symbols: symbols of the model that receives the weights.
        trained_weights: pretrained symbol weights; its first dimension must
            equal ``len(trained_symbols)``.
        trained_symbols: symbols belonging to ``trained_weights``.
        custom_mapping: optional explicit mapping. When ``None`` a mapping is
            built from the intersection of both symbol sets. When given, the
            passed object itself is modified by ``remove_unknown_symbols`` and
            ``pop_batch``.
        hparams: supplies ``n_accents``, ``n_symbols`` and
            ``accents_use_own_symbols`` and is passed to ``get_symbol_weights``.
        logger: receives the error, warning and info messages.

    Returns:
        The tensor created by ``get_symbol_weights(hparams)`` after
        ``map_weights`` was applied to it.

    Raises:
        Exception: if the weight count and the symbol count differ.
    """
    if trained_weights.shape[0] != len(trained_symbols):
        msg = f"Weights mapping: symbol space from pretrained model ({trained_weights.shape[0]}) did not match amount of symbols ({len(trained_symbols)})."
        # No exception is being handled here, so log a plain error instead of
        # logger.exception, and put the message on the raised exception too.
        logger.error(msg)
        raise Exception(msg)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        symbols_map = custom_mapping
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols())

    # Remove all empty mappings
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )

    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols)

    model_weights = get_symbol_weights(hparams)

    map_weights(model_symbols_id_map=model_symbols_id_map,
                model_weights=model_weights,
                trained_weights=trained_weights,
                logger=logger)

    # Report every model symbol that ended up without a source symbol.
    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if len(no_mapping) > 0:
        logger.warning(f"Following symbols were not mapped: {no_mapping}")
    else:
        logger.info("All symbols were mapped.")

    return model_weights
def convert_to_symbols_ids_map(self, to_symbols: SymbolIdDict, from_symbols: SymbolIdDict) -> OrderedDictType[int, int]:
    """Translate this symbol-to-symbol mapping into an id-to-id mapping.

    Args:
        to_symbols: resolves the ids of the keys of this mapping.
        from_symbols: resolves the ids of the values of this mapping.

    Returns:
        Ordered mapping ``target id -> source id`` in the iteration order of
        this mapping.
    """
    id_pairs: OrderedDictType[int, int] = OrderedDict()
    for target_symbol, source_symbol in self.items():
        # Both sides must be known to their respective symbol tables.
        assert to_symbols.symbol_exists(target_symbol)
        assert from_symbols.symbol_exists(source_symbol)
        source_id = from_symbols.get_id(source_symbol)
        target_id = to_symbols.get_id(target_symbol)
        id_pairs[target_id] = source_id
    return id_pairs
def test_get_symbols_id_mapping_without_map(self):
    # Builds a mapping from the shared symbols, adds "a" <- "c", and checks the
    # resulting id map.
    from_symbols = {"b", "c"}
    to_symbols = {"a", "b"}
    from_conv = SymbolIdDict.init_from_symbols(from_symbols)
    to_conv = SymbolIdDict.init_from_symbols(to_symbols)
    mapping = SymbolsMap.from_intersection(from_symbols, to_symbols)
    mapping.update_existing_to_mappings({"a": "c"})

    symbols_id_map = mapping.convert_to_symbols_ids_map(
        to_symbols=to_conv,
        from_symbols=from_conv,
    )

    # convert_to_symbols_ids_map stores target id -> source id, so the lookup
    # key is the id in `to_conv` and the expected value is the id in
    # `from_conv`. Indexing the other way round only passes when the two id
    # assignments happen to mirror each other.
    self.assertEqual(from_conv.get_id("b"), symbols_id_map[to_conv.get_id("b")])
    self.assertEqual(from_conv.get_id("c"), symbols_id_map[to_conv.get_id("a")])
    self.assertEqual(2, len(symbols_id_map))
def get_formatted(self, symbol_id_dict: SymbolIdDict, accent_id_dict: AccentsDict, pairs_per_line=170, space_length=0):
    """Format this sentence via ``get_formatted_core``.

    Args:
        symbol_id_dict: resolves ``self.serialized_symbols`` to symbols.
        accent_id_dict: forwarded unchanged to ``get_formatted_core``.
        pairs_per_line: forwarded as ``max_pairs_per_line``.
        space_length: forwarded as ``space_length``.

    Returns:
        Whatever ``get_formatted_core`` returns.
    """
    symbols = symbol_id_dict.get_symbols(self.serialized_symbols)
    accent_ids = self.get_accent_ids()
    return get_formatted_core(
        sent_id=self.sent_id,
        symbols=symbols,
        accent_ids=accent_ids,
        accent_id_dict=accent_id_dict,
        space_length=space_length,
        max_pairs_per_line=pairs_per_line,
    )
def _prepare_data(
    processed_data: List[Tuple[int, List[str], List[int], Language]]
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    """Build the text data list, symbol id table and symbol counts.

    Args:
        processed_data: tuples ``(entry_id, symbols, accent_ids, lang)``;
            ``symbols`` and ``accent_ids`` must have the same length.

    Returns:
        Tuple ``(result, conv, symbols_dict)``: the ``TextDataList`` with one
        ``TextData`` per input tuple, the ``SymbolIdDict`` created from all
        occurring symbols, and the ``SymbolsDict`` created from the symbol
        counter.
    """
    result = TextDataList()
    symbol_counter = get_counter([x[1] for x in processed_data])
    symbols_dict = SymbolsDict.fromcounter(symbol_counter)
    conv = SymbolIdDict.init_from_symbols(set(symbols_dict.keys()))

    for entry_id, symbols, accent_ids, lang in processed_data:
        # Every symbol carries exactly one accent id.
        assert len(accent_ids) == len(symbols)
        text = SymbolIdDict.symbols_to_text(symbols)
        serialized_symbol_ids = conv.get_serialized_ids(symbols)
        serialized_accent_ids = serialize_list(accent_ids)
        data = TextData(entry_id, text, serialized_symbol_ids, serialized_accent_ids, lang)
        result.append(data)

    return result, conv, symbols_dict
def test_sims_to_csv(self):
    # Row 0 of the embedding is zeroed, so the similarity between the two
    # symbols is expected to be rendered as 'nan'.
    emb = torch.ones(size=(2, 3))
    torch.nn.init.zeros_(emb[0])
    sims = get_similarities(emb)
    symbols = SymbolIdDict.init_from_symbols({"2", "1"})

    res = sims_to_csv(sims, symbols)

    self.assertEqual(2, len(res.index))
    expected_rows = [
        ['2', '<=>', '1', 'nan'],
        ['1', '<=>', '2', 'nan'],
    ]
    for row_index, expected_row in enumerate(expected_rows):
        self.assertListEqual(expected_row, list(res.values[row_index]))
def test_plot_embeddings(self):
    # Two identical embedding rows for the symbols "a" and "b".
    emb = torch.ones(size=(2, 3))
    # A non-empty symbol set is required: plot_embeddings asserts that the
    # number of embedding rows equals the number of symbols, and the expected
    # rows below name "a" and "b". (`{}` is an empty dict, not a set.)
    symbols = SymbolIdDict.init_from_symbols({"a", "b"})

    text, plot2d, plot3d = plot_embeddings(symbols, emb, logging.getLogger())

    self.assertEqual(2, len(text.index))
    self.assertListEqual(['a', '<=>', 'b', '1.00'], list(text.values[0]))
    self.assertListEqual(['b', '<=>', 'a', '1.00'], list(text.values[1]))
    self.assertEqual("2D-Embeddings", plot2d.layout.title.text)
    self.assertEqual("3D-Embeddings", plot3d.layout.title.text)
def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap, ignore_arcs: bool) -> Tuple[SymbolIdDict, SentenceList]:
    """Apply a symbol mapping to all sentences and re-split the mapped text.

    Args:
        sentences: input sentences.
        text_symbols: resolves the serialized symbols of the input sentences.
        symbols_map: mapping applied through ``apply_to_symbols``.
        ignore_arcs: forwarded to ``text_to_symbols``.

    Returns:
        Result of ``update_symbols_and_text`` for the newly created sentences.
    """
    collected_symbols = []
    mapped_sentences = SentenceList()
    next_sent_id = 0

    for sentence in sentences.items():
        original_symbols = text_symbols.get_symbols(sentence.serialized_symbols)
        accent_ids = deserialize_list(sentence.serialized_accents)
        mapped_text = SymbolIdDict.symbols_to_text(
            symbols_map.apply_to_symbols(original_symbols))

        # a resulting empty text would make no problems
        for new_sent_text in split_sentences(mapped_text, sentence.lang):
            new_symbols = text_to_symbols(
                new_sent_text,
                lang=sentence.lang,
                ignore_tones=False,
                ignore_arcs=ignore_arcs
            )

            # Every new symbol gets the first accent of the source sentence.
            new_accent_ids = [accent_ids[0]] * len(new_symbols) if len(accent_ids) > 0 else []
            assert len(new_accent_ids) == len(new_symbols)

            next_sent_id += 1
            new_sentence = Sentence(
                sent_id=next_sent_id,
                text=new_sent_text,
                lang=sentence.lang,
                serialized_accents=serialize_list(new_accent_ids),
                # filled in by update_symbols_and_text below
                serialized_symbols=""
            )
            collected_symbols.append(new_symbols)
            assert len(new_sentence.get_accent_ids()) == len(new_symbols)
            mapped_sentences.append(new_sentence)

    return update_symbols_and_text(mapped_sentences, collected_symbols)
def sims_to_csv(sims: Dict[int, List[Tuple[int, float]]], symbols: SymbolIdDict) -> pd.DataFrame:
    """Render symbol similarities as a table with one row per symbol.

    Each row is ``[symbol, "<=>", other symbol, similarity, ...]`` with the
    similarity formatted to two decimals.

    Args:
        sims: maps a symbol id to ``(other symbol id, similarity)`` pairs; must
            contain as many entries as ``symbols``.
        symbols: resolves symbol ids to symbols.

    Returns:
        A ``pd.DataFrame`` built from the rows.
    """
    assert len(sims) == len(symbols)

    rows = []
    for symbol_id, similarities in sims.items():
        row = [f"{symbols.get_symbol(symbol_id)}", "<=>"]
        for other_symbol_id, similarity in similarities:
            row.extend((symbols.get_symbol(other_symbol_id), f"{similarity:.2f}"))
        rows.append(row)

    return pd.DataFrame(rows)
def from_sentences(cls, sentences: SentenceList, accents: AccentsDict, symbols: SymbolIdDict):
    """Create an instance holding one ``InferSentence`` per input sentence.

    Args:
        sentences: source sentences.
        accents: resolves each sentence's serialized accents.
        symbols: resolves each sentence's serialized symbols.

    Returns:
        A new ``cls`` instance with the converted sentences appended in order.
    """
    instance = cls()
    for sentence in sentences.items():
        sent_symbols = symbols.get_symbols(sentence.serialized_symbols)
        sent_accents = accents.get_accents(sentence.serialized_accents)
        converted = InferSentence(
            sent_id=sentence.sent_id,
            symbols=sent_symbols,
            accents=sent_accents,
        )
        # Symbols and accents must be aligned one-to-one.
        assert len(converted.symbols) == len(converted.accents)
        instance.append(converted)
    return instance
def from_instances(cls, model: Tacotron2, optimizer: Adam, hparams: HParams, iteration: int, symbols: SymbolIdDict, accents: AccentsDict, speakers: SpeakersDict):
    """Create a checkpoint from live training objects.

    The model and optimizer contribute their ``state_dict()``, ``hparams`` is
    stored via ``asdict`` together with its ``learning_rate``, and the three
    dictionaries are stored via their ``raw()`` representation.

    Returns:
        A new ``cls`` instance.
    """
    return cls(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        learning_rate=hparams.learning_rate,
        iteration=iteration,
        hparams=asdict(hparams),
        symbols=symbols.raw(),
        accents=accents.raw(),
        speakers=speakers.raw(),
    )
def add_text(text: str, lang: Language) -> Tuple[SymbolIdDict, SentenceList]:
    """Split a text into sentences and convert each into a ``Sentence``.

    Every symbol gets the accent id 0; sentence ids start at 1.

    Args:
        text: text to split with ``split_sentences``.
        lang: language passed to the splitter, to ``text_to_symbols`` and
            stored on every sentence.

    Returns:
        Tuple ``(symbols, res)``: the ``SymbolIdDict`` built from all occurring
        symbols and the ``SentenceList`` of created sentences.
    """
    default_accent_id = 0

    sents_symbols: List[List[str]] = []
    for sent in split_sentences(text, lang):
        sents_symbols.append(text_to_symbols(
            sent,
            lang=lang,
            ignore_tones=False,
            ignore_arcs=False
        ))

    symbols = SymbolIdDict.init_from_symbols(get_unique_items(sents_symbols))

    res = SentenceList()
    for sent_id, sent_symbols in enumerate(sents_symbols, start=1):
        accent_ids = [default_accent_id] * len(sent_symbols)
        res.append(Sentence(
            sent_id=sent_id,
            lang=lang,
            serialized_symbols=symbols.get_serialized_ids(sent_symbols),
            serialized_accents=serialize_list(accent_ids),
            text=SymbolIdDict.symbols_to_text(sent_symbols),
        ))

    return symbols, res
def preprocess(
    data: DsDataList,
    symbol_ids: SymbolIdDict
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    """Convert dataset entries into text data via ``_prepare_data``.

    Args:
        data: dataset entries; iterated with a ``tqdm`` progress bar.
        symbol_ids: resolves the deserialized symbol ids of each entry.

    Returns:
        Result of ``_prepare_data`` for the collected
        ``(entry_id, symbols, accents, lang)`` tuples.
    """
    collected: List[Tuple[int, List[str], List[int], Language]] = []

    entry: DsData
    for entry in tqdm(data):
        entry_symbol_ids = deserialize_list(entry.serialized_symbols)
        entry_symbols: List[str] = symbol_ids.get_symbols(entry_symbol_ids)
        entry_accents: List[int] = deserialize_list(entry.serialized_accents)
        collected.append((entry.entry_id, entry_symbols, entry_accents, entry.lang))

    return _prepare_data(collected)
def plot_embeddings(symbols: SymbolIdDict, emb: torch.Tensor, logger: Logger) -> Tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Create the similarity table and the 2D/3D plots for symbol embeddings.

    Args:
        symbols: symbols of the embedding; its length must equal the first
            dimension of ``emb``.
        emb: embedding tensor; converted with ``.numpy()`` for the similarity
            computation.
        logger: receives the embedding shape and the symbol count.

    Returns:
        Tuple ``(df, fig_2d, fig_3d)``.
    """
    n_symbols = len(symbols)
    assert emb.shape[0] == n_symbols

    logger.info(f"Emb size {emb.shape}")
    logger.info(f"Sym len {len(symbols)}")

    # Similarity table.
    df = sims_to_csv(get_similarities(emb.numpy()), symbols)

    # Both plots share the same normalized embedding and labels.
    all_symbols = symbols.get_all_symbols()
    emb_normed = norm2emb(emb)
    fig_2d = emb_plot_2d(emb_normed, all_symbols)
    fig_3d = emb_plot_3d(emb_normed, all_symbols)

    return df, fig_2d, fig_3d
def sents_normalize(sentences: SentenceList, text_symbols: SymbolIdDict) -> Tuple[SymbolIdDict, SentenceList]:
    """Normalize the symbols of every sentence.

    The serialized accents of each sentence are updated in place; symbols and
    text are rewritten by ``update_symbols_and_text``.

    Args:
        sentences: sentences to normalize.
        text_symbols: resolves the serialized symbols of the sentences.

    Returns:
        Result of ``update_symbols_and_text``.
    """
    # Maybe add info if something was unknown
    normalized_symbols = []
    for sentence in sentences.items():
        current_symbols = text_symbols.get_symbols(sentence.serialized_symbols)
        current_accent_ids = deserialize_list(sentence.serialized_accents)
        new_symbols, new_accent_ids = symbols_normalize(
            symbols=current_symbols,
            lang=sentence.lang,
            accent_ids=current_accent_ids,
        )
        # TODO: check if new sentences resulted and then split them.
        sentence.serialized_accents = serialize_list(new_accent_ids)
        normalized_symbols.append(new_symbols)

    return update_symbols_and_text(sentences, normalized_symbols)
def sents_accent_template(sentences: SentenceList, text_symbols: SymbolIdDict, accent_ids: AccentsDict) -> AccentedSymbolList:
    """List every symbol of every sentence together with its accent.

    Each entry's ``position`` is ``"<sentence index>-<symbol index>"``, both
    zero-based.

    Args:
        sentences: source sentences.
        text_symbols: resolves serialized symbols.
        accent_ids: resolves serialized accents.

    Returns:
        ``AccentedSymbolList`` with one ``AccentedSymbol`` per symbol/accent
        pair.
    """
    template = AccentedSymbolList()
    for sent_index, sent in enumerate(sentences.items()):
        sent_symbols = text_symbols.get_symbols(sent.serialized_symbols)
        sent_accents = accent_ids.get_accents(sent.serialized_accents)
        for symbol_index, (symbol, accent) in enumerate(zip(sent_symbols, sent_accents)):
            template.append(AccentedSymbol(
                position=f"{sent_index}-{symbol_index}",
                symbol=symbol,
                accent=accent
            ))
    return template
def make_common_symbol_ids(self) -> SymbolIdDict:
    """Give all datasets in ``self.data`` one shared symbol id table.

    The shared table is created with padding from the union of all datasets'
    symbols. Every entry's serialized symbol ids are re-encoded against it and
    each dataset's ``symbol_ids`` is replaced by it.

    Returns:
        The shared ``SymbolIdDict``.
    """
    union_of_symbols: Set[str] = set()
    for dataset in self.data:
        union_of_symbols.update(dataset.symbol_ids.get_all_symbols())

    shared_symbol_ids = SymbolIdDict.init_from_symbols_with_pad(union_of_symbols)

    for dataset in self.data:
        for entry in dataset.data.items():
            decoded = dataset.symbol_ids.get_symbols(entry.serialized_symbol_ids)
            entry.serialized_symbol_ids = shared_symbol_ids.get_serialized_ids(decoded)
        # Swap the table only after all entries were decoded with the old one.
        dataset.symbol_ids = shared_symbol_ids

    return shared_symbol_ids
def _get_ds_data(l: PreDataList, speakers_dict: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict) -> DsDataList:
    """Convert preprocessed entries into a ``DsDataList``.

    The zero-based position of an entry becomes its ``entry_id``.

    Args:
        l: preprocessed entries.
        speakers_dict: maps a speaker name to its id.
        accents: serializes the accents of each entry.
        symbols: serializes the symbols of each entry.

    Returns:
        ``DsDataList`` with one ``DsData`` per entry.
    """
    entries = []
    for entry_id, values in enumerate(l.items()):
        entries.append(DsData(
            entry_id=entry_id,
            basename=values.name,
            speaker_name=values.speaker_name,
            speaker_id=speakers_dict[values.speaker_name],
            text=values.text,
            serialized_symbols=symbols.get_serialized_ids(values.symbols),
            serialized_accents=accents.get_serialized_ids(values.accents),
            wav_path=values.wav_path,
            lang=values.lang,
            gender=values.gender,
        ))
    return DsDataList(entries)
def symbols_convert_to_ipa(symbols: List[str], lang: Language, accent_ids: List[int], ignore_tones: bool, ignore_arcs: bool) -> Tuple[List[str], List[int]]:
    """Convert symbols to IPA symbols and create matching accent ids.

    Args:
        symbols: symbols of the text to convert.
        lang: language of ``symbols``.
        accent_ids: one accent id per symbol.
        ignore_tones: forwarded to ``text_to_symbols``.
        ignore_arcs: forwarded to ``text_to_symbols``.

    Returns:
        Tuple ``(new_symbols, new_accent_ids)``. All new symbols get the first
        input accent id; with no input accent ids the accent list is empty.
    """
    assert len(symbols) == len(accent_ids)
    # Note: do also for ipa symbols to have possibility to remove arcs and tones
    orig_text = SymbolIdDict.symbols_to_text(symbols)
    ipa = convert_to_ipa(orig_text, lang)
    new_symbols: List[str] = text_to_symbols(ipa, Language.IPA, ignore_tones, ignore_arcs)

    # The conversion can change the number of symbols, so the accents cannot
    # be carried over position by position.
    if len(accent_ids) > 0:
        new_accent_ids = [accent_ids[0]] * len(new_symbols)
    else:
        new_accent_ids = []

    assert len(new_symbols) == len(new_accent_ids)
    return new_symbols, new_accent_ids
def test_get_similarities_is_sorted(self):
    # For each symbol the neighbours must be ordered as listed in
    # `expected_order` below.
    symbols = SymbolIdDict.init_from_symbols({"a", "b", "c"})
    vectors = {
        "a": [0.5, 1.0, 0],
        "b": [1.0, 0.6, 0],
        "c": [1.0, 0.5, 0],
    }
    emb = np.zeros(shape=(3, 3))
    for symbol, vector in vectors.items():
        emb[symbols.get_id(symbol)] = vector

    sims = get_similarities(emb)

    self.assertEqual(3, len(sims))
    expected_order = {
        "a": ("b", "c"),
        "c": ("b", "a"),
        "b": ("c", "a"),
    }
    for symbol, neighbours in expected_order.items():
        for rank, neighbour in enumerate(neighbours):
            self.assertEqual(symbols.get_id(neighbour), sims[symbols.get_id(symbol)][rank][0])
def symbols_normalize(symbols: List[str], lang: Language, accent_ids: List[int]) -> Tuple[List[str], List[int]]:
    """Normalize the text formed by the symbols and create matching accent ids.

    Args:
        symbols: symbols of the text to normalize.
        lang: language of ``symbols``.
        accent_ids: one accent id per symbol.

    Returns:
        Tuple ``(new_symbols, new_accent_ids)``. For ``Language.IPA`` the input
        accent ids are returned unchanged. For every other language all new
        symbols get the first input accent id, or no accents if none were
        given.
    """
    assert len(symbols) == len(accent_ids)
    orig_text = SymbolIdDict.symbols_to_text(symbols)
    text = normalize(orig_text, lang)
    new_symbols: List[str] = text_to_symbols(text, lang)

    if lang != Language.IPA:
        if len(accent_ids) > 0:
            new_accent_ids = [accent_ids[0]] * len(new_symbols)
        else:
            new_accent_ids = []
    else:
        # because no replacing was done in ipa normalization
        # maybe support remove whitespace
        new_accent_ids = accent_ids

    assert len(new_symbols) == len(new_accent_ids)
    return new_symbols, new_accent_ids
def convert_to_ipa(
    data: TextDataList,
    symbol_converter: SymbolIdDict,
    ignore_tones: bool,
    ignore_arcs: bool
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    """Convert all text entries to IPA and rebuild the text data.

    Args:
        data: entries to convert; iterated with a ``tqdm`` progress bar.
        symbol_converter: resolves the serialized symbol ids of the entries.
        ignore_tones: forwarded to ``symbols_convert_to_ipa``.
        ignore_arcs: forwarded to ``symbols_convert_to_ipa``.

    Returns:
        Result of ``_prepare_data``; every entry is tagged ``Language.IPA``.
    """
    converted: List[Tuple[int, List[str], List[int], Language]] = []

    entry: TextData
    for entry in tqdm(data.items()):
        entry_symbols = symbol_converter.get_symbols(entry.serialized_symbol_ids)
        entry_accent_ids = deserialize_list(entry.serialized_accent_ids)
        ipa_symbols, ipa_accent_ids = symbols_convert_to_ipa(
            symbols=entry_symbols,
            lang=entry.lang,
            accent_ids=entry_accent_ids,
            ignore_arcs=ignore_arcs,
            ignore_tones=ignore_tones,
        )
        converted.append((entry.entry_id, ipa_symbols, ipa_accent_ids, Language.IPA))

    return _prepare_data(converted)
def normalize(
    data: TextDataList,
    symbol_converter: SymbolIdDict
) -> Tuple[TextDataList, SymbolIdDict, SymbolsDict]:
    """Normalize all text entries and rebuild the text data.

    Args:
        data: entries to normalize; iterated with a ``tqdm`` progress bar.
        symbol_converter: resolves the serialized symbol ids of the entries.

    Returns:
        Result of ``_prepare_data``; every entry keeps its language.
    """
    normalized: List[Tuple[int, List[str], List[int], Language]] = []

    entry: TextData
    for entry in tqdm(data):
        entry_symbols = symbol_converter.get_symbols(entry.serialized_symbol_ids)
        entry_accent_ids = deserialize_list(entry.serialized_accent_ids)
        new_symbols, new_accent_ids = symbols_normalize(
            symbols=entry_symbols,
            lang=entry.lang,
            accent_ids=entry_accent_ids,
        )
        normalized.append((entry.entry_id, new_symbols, new_accent_ids, entry.lang))

    return _prepare_data(normalized)
def sents_convert_to_ipa(sentences: SentenceList, text_symbols: SymbolIdDict, ignore_tones: bool, ignore_arcs: bool) -> Tuple[SymbolIdDict, SentenceList]:
    """Convert every sentence to IPA.

    Language and serialized accents of each sentence are updated in place;
    symbols and text are rewritten by ``update_symbols_and_text``.

    Args:
        sentences: sentences to convert.
        text_symbols: resolves the serialized symbols of the sentences.
        ignore_tones: forwarded to ``symbols_convert_to_ipa``.
        ignore_arcs: forwarded to ``symbols_convert_to_ipa``.

    Returns:
        Result of ``update_symbols_and_text``.
    """
    ipa_symbols_per_sentence = []
    for sentence in sentences.items(True):
        current_symbols = text_symbols.get_symbols(sentence.serialized_symbols)
        current_accent_ids = deserialize_list(sentence.serialized_accents)
        new_symbols, new_accent_ids = symbols_convert_to_ipa(
            symbols=current_symbols,
            lang=sentence.lang,
            accent_ids=current_accent_ids,
            ignore_arcs=ignore_arcs,
            ignore_tones=ignore_tones,
        )
        assert len(new_symbols) == len(new_accent_ids)

        sentence.lang = Language.IPA
        sentence.serialized_accents = serialize_list(new_accent_ids)
        ipa_symbols_per_sentence.append(new_symbols)
        # The stored accents must match the new symbols.
        assert len(sentence.get_accent_ids()) == len(new_symbols)

    return update_symbols_and_text(sentences, ipa_symbols_per_sentence)
def convert_v1_to_v2_model(old_model_path: str, custom_hparams: Optional[Dict[str, str]], speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict):
    """Wrap an old checkpoint file into a ``CheckpointTacotron`` and save it.

    The new file is written next to the old one; its name is the old path
    followed by ``_`` and ``get_pytorch_filename(iteration)``. The stored
    iteration is the old iteration plus one.

    Args:
        old_model_path: path of the checkpoint loaded with ``torch.load``.
        custom_hparams: overrides applied on top of the fresh ``HParams``.
        speakers: stored in the checkpoint; its length sets ``n_speakers``.
        accents: stored in the checkpoint; its length sets ``n_accents``.
        symbols: stored in the checkpoint; its length sets ``n_symbols``.
    """
    checkpoint_dict = torch.load(old_model_path, map_location='cpu')

    base_hparams = HParams(
        n_speakers=len(speakers),
        n_accents=len(accents),
        n_symbols=len(symbols),
    )
    hparams = overwrite_custom_hparams(base_hparams, custom_hparams)

    chp = CheckpointTacotron(
        state_dict=checkpoint_dict["state_dict"],
        optimizer=checkpoint_dict["optimizer"],
        learning_rate=checkpoint_dict["learning_rate"],
        iteration=checkpoint_dict["iteration"] + 1,
        hparams=asdict(hparams),
        speakers=speakers.raw(),
        symbols=symbols.raw(),
        accents=accents.raw(),
    )

    new_model_path = f"{old_model_path}_{get_pytorch_filename(chp.iteration)}"
    chp.save(new_model_path, logging.getLogger())
def _train(custom_hparams: Optional[Dict[str, str]], taco_logger: Tacotron2Logger, trainset: PreparedDataList, valset: PreparedDataList, save_callback: Any, speakers: SpeakersDict, accents: AccentsDict, symbols: SymbolIdDict, checkpoint: Optional[CheckpointTacotron], warm_model: Optional[CheckpointTacotron], weights_checkpoint: Optional[CheckpointTacotron], weights_map: SymbolsMap, map_from_speaker_name: Optional[str], logger: Logger, checkpoint_logger: Logger):
    """Train the model, logging progress and validating at every save point.

    Params
    ------
    custom_hparams: overrides applied to the hyperparameters via
        `overwrite_custom_hparams`.
    taco_logger: receives `log_training` calls and is passed to `validate`.
    trainset / valset: data passed to `prepare_trainloader` / `prepare_valloader`.
    save_callback: called with every `CheckpointTacotron` created at a save
        iteration.
    speakers / accents / symbols: stored in every created checkpoint; their
        lengths initialize fresh hyperparameters; `symbols` and `accents`
        also supply the padding ids for the collate function.
    checkpoint: when given, hyperparameters and the iteration come from it and
        neither warm start nor weight mapping is done.
    warm_model: used for `warm_start_model` when no checkpoint is given.
    weights_checkpoint: source of symbol (and optionally speaker) embedding
        weights when no checkpoint is given.
    weights_map: passed as `custom_mapping` to `get_mapped_symbol_weights`.
    map_from_speaker_name: when not None, speaker embedding weights are mapped
        from `weights_checkpoint` as well.
    logger: general progress and error messages.
    checkpoint_logger: passed to `log_checkpoint_score`.

    Raises
    ------
    Exception: if the train loader yields no batches.
    """

    complete_start = time.time()

    # Hyperparameters: continue with the stored ones or start fresh.
    if checkpoint is not None:
        hparams = checkpoint.get_hparams(logger)
    else:
        hparams = HParams(
            n_accents=len(accents),
            n_speakers=len(speakers),
            n_symbols=len(symbols)
        )
    # is it problematic to change the batch size on a trained model?
    hparams = overwrite_custom_hparams(hparams, custom_hparams)

    assert hparams.n_accents > 0
    assert hparams.n_speakers > 0
    assert hparams.n_symbols > 0

    if hparams.use_saved_learning_rate and checkpoint is not None:
        hparams.learning_rate = checkpoint.learning_rate

    log_hparams(hparams, logger)
    init_torch(hparams)

    model, optimizer = load_model_and_optimizer(
        hparams=hparams,
        checkpoint=checkpoint,
        logger=logger,
    )

    iteration = get_iteration(checkpoint)

    # Warm start and embedding mapping only apply to a fresh training run.
    if checkpoint is None:
        if warm_model is not None:
            logger.info("Loading states from pretrained model...")
            warm_start_model(model, warm_model, hparams, logger)

        if weights_checkpoint is not None:
            logger.info("Mapping symbol embeddings...")

            pretrained_symbol_weights = get_mapped_symbol_weights(
                model_symbols=symbols,
                trained_weights=weights_checkpoint.get_symbol_embedding_weights(),
                trained_symbols=weights_checkpoint.get_symbols(),
                custom_mapping=weights_map,
                hparams=hparams,
                logger=logger
            )

            update_weights(model.embedding, pretrained_symbol_weights)

            # Speaker weights come from the same weights checkpoint.
            map_speaker_weights = map_from_speaker_name is not None
            if map_speaker_weights:
                logger.info("Mapping speaker embeddings...")
                pretrained_speaker_weights = get_mapped_speaker_weights(
                    model_speakers=speakers,
                    trained_weights=weights_checkpoint.get_speaker_embedding_weights(),
                    trained_speaker=weights_checkpoint.get_speakers(),
                    map_from_speaker_name=map_from_speaker_name,
                    hparams=hparams,
                    logger=logger,
                )

                update_weights(model.speakers_embedding, pretrained_speaker_weights)

    log_symbol_weights(model, logger)

    collate_fn = SymbolsMelCollate(
        n_frames_per_step=hparams.n_frames_per_step,
        padding_symbol_id=symbols.get_id(PADDING_SYMBOL),
        padding_accent_id=accents.get_id(PADDING_ACCENT)
    )

    val_loader = prepare_valloader(hparams, collate_fn, valset, logger)
    train_loader = prepare_trainloader(hparams, collate_fn, trainset, logger)

    # Number of batches per epoch.
    batch_iterations = len(train_loader)
    enough_traindata = batch_iterations > 0
    if not enough_traindata:
        msg = "Not enough trainingdata."
        logger.error(msg)
        raise Exception(msg)

    save_it_settings = SaveIterationSettings(
        epochs=hparams.epochs,
        batch_iterations=batch_iterations,
        save_first_iteration=True,
        save_last_iteration=True,
        iters_per_checkpoint=hparams.iters_per_checkpoint,
        epochs_per_checkpoint=hparams.epochs_per_checkpoint
    )

    criterion = Tacotron2Loss()
    batch_durations: List[float] = []

    train_start = time.perf_counter()
    start = train_start
    model.train()
    continue_epoch = get_continue_epoch(iteration, batch_iterations)
    for epoch in range(continue_epoch, hparams.epochs):
        # Batches of this epoch that a resumed run has to skip.
        next_batch_iteration = get_continue_batch_iteration(iteration, batch_iterations)
        skip_bar = None
        if next_batch_iteration > 0:
            logger.debug(f"Current batch is {next_batch_iteration} of {batch_iterations}")
            logger.debug("Skipping batches...")
            skip_bar = tqdm(total=next_batch_iteration)
        for batch_iteration, batch in enumerate(train_loader):
            need_to_skip_batch = skip_batch(
                batch_iteration=batch_iteration,
                continue_batch_iteration=next_batch_iteration
            )

            if need_to_skip_batch:
                assert skip_bar is not None
                skip_bar.update(1)
                #debug_logger.debug(f"Skipped batch {batch_iteration + 1}/{next_batch_iteration + 1}.")
                continue
            # debug_logger.debug(f"Current batch: {batch[0][0]}")

            # update_learning_rate_optimizer(optimizer, hparams.learning_rate)

            # One optimization step.
            model.zero_grad()
            x, y = parse_batch(batch)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            reduced_loss = loss.item()

            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(),
                max_norm=hparams.grad_clip_thresh
            )

            optimizer.step()

            iteration += 1

            # Timing statistics; `start` is moved so that the next duration
            # covers exactly one batch.
            end = time.perf_counter()
            duration = end - start
            start = end

            batch_durations.append(duration)
            avg_batch_dur = np.mean(batch_durations)
            avg_epoch_dur = avg_batch_dur * batch_iterations
            remaining_its = hparams.epochs * batch_iterations - iteration
            estimated_remaining_duration = avg_batch_dur * remaining_its

            next_it = get_next_save_it(iteration, save_it_settings)
            next_checkpoint_save_time = 0
            if next_it is not None:
                next_checkpoint_save_time = (next_it - iteration) * avg_batch_dur

            logger.info(" | ".join([
                f"Epoch: {get_formatted_current_total(epoch + 1, hparams.epochs)}",
                f"It.: {get_formatted_current_total(batch_iteration + 1, batch_iterations)}",
                f"Tot. it.: {get_formatted_current_total(iteration, hparams.epochs * batch_iterations)} ({iteration / (hparams.epochs * batch_iterations) * 100:.2f}%)",
                f"Loss: {reduced_loss:.6f}",
                f"Grad norm: {grad_norm:.6f}",
                #f"Dur.: {duration:.2f}s/it",
                f"Avg. dur.: {avg_batch_dur:.2f}s/it & {avg_epoch_dur / 60:.0f}min/epoch",
                f"Tot. dur.: {(time.perf_counter() - train_start) / 60 / 60:.2f}h/{estimated_remaining_duration / 60 / 60:.0f}h ({estimated_remaining_duration / 60 / 60 / 24:.1f}days)",
                f"Next checkpoint: {next_checkpoint_save_time / 60:.0f}min",
            ]))

            taco_logger.log_training(reduced_loss, grad_norm, hparams.learning_rate,
                                     duration, iteration)

            # Save, validate and record the score at the configured iterations.
            save_it = check_save_it(epoch, iteration, save_it_settings)
            if save_it:
                checkpoint = CheckpointTacotron.from_instances(
                    model=model,
                    optimizer=optimizer,
                    hparams=hparams,
                    iteration=iteration,
                    symbols=symbols,
                    accents=accents,
                    speakers=speakers
                )

                save_callback(checkpoint)

                valloss = validate(model, criterion, val_loader, iteration, taco_logger, logger)

                # if rank == 0:
                log_checkpoint_score(iteration, grad_norm,
                                     reduced_loss, valloss, epoch, batch_iteration, checkpoint_logger)

    duration_s = time.time() - complete_start
    logger.info(f'Finished training. Total duration: {duration_s / 60:.2f}min')
def save_prep_symbol_converter(prep_dir: str, data: SymbolIdDict):
    """Save the symbol id table into the preparation directory.

    Args:
        prep_dir: directory in which the file named by
            ``_prepared_symbols_json`` is written.
        data: symbol table whose ``save`` method is called with that path.
    """
    target_path = os.path.join(prep_dir, _prepared_symbols_json)
    data.save(target_path)