def get_mapped_symbol_weights(
    model_symbols: SymbolIdDict,
    trained_weights: Tensor,
    trained_symbols: SymbolIdDict,
    custom_mapping: Optional[SymbolsMap],
    hparams: HParams,
    logger: Logger,
) -> Tensor:
    """Create model symbol weights, filling mapped rows from pretrained weights.

    A symbol map (model symbol -> trained symbol) is either derived from the
    intersection of both symbol sets or taken from ``custom_mapping``. Entries
    with an empty mapping are dropped, the map is converted to ids (expanded
    for accents according to ``hparams``), fresh weights are created with
    ``get_symbol_weights(hparams)``, and ``map_weights`` transfers the trained
    rows into them. Symbols that end up unmapped are logged as a warning.

    Args:
        model_symbols: Symbol/id dictionary of the model being initialised.
        trained_weights: Pretrained symbol weights; dimension 0 must equal
            ``len(trained_symbols)`` (presumably one row per trained symbol).
        trained_symbols: Symbol/id dictionary of the pretrained model.
        custom_mapping: Optional explicit model-symbol -> trained-symbol map.
            It is copied before being pruned, so the caller's object is left
            untouched. If ``None``, the intersection of both symbol sets is used.
        hparams: Hyperparameters; ``n_accents``, ``n_symbols`` and
            ``accents_use_own_symbols`` are read here.
        logger: Logger for mismatch errors and mapping diagnostics.

    Returns:
        The newly created model weights with mapped rows filled in by
        ``map_weights``.

    Raises:
        ValueError: If ``trained_weights.shape[0]`` differs from the number of
            trained symbols. This is a subclass of ``Exception``, so callers
            that previously caught the bare ``Exception`` still catch it.
    """
    import copy  # local import: the file's import block is not visible here

    n_trained_rows = trained_weights.shape[0]
    if n_trained_rows != len(trained_symbols):
        msg = (
            f"Weights mapping: symbol space from pretrained model ({n_trained_rows}) "
            f"did not match amount of symbols ({len(trained_symbols)})."
        )
        # logger.exception is meant for use inside an ``except`` block; outside
        # one it appends a meaningless "NoneType: None" traceback.
        logger.error(msg)
        raise ValueError(msg)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        # remove_unknown_symbols / pop_batch are called for their side effect
        # (their return values are unused), so prune a copy rather than the
        # caller's mapping, which may be reused elsewhere.
        symbols_map = copy.deepcopy(custom_mapping)
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols(),
        )

    # Remove all empty mappings
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )
    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols,
    )

    model_weights = get_symbol_weights(hparams)
    map_weights(
        model_symbols_id_map=model_symbols_id_map,
        model_weights=model_weights,
        trained_weights=trained_weights,
        logger=logger,
    )

    # Model symbols absent from the (pruned) map keys were never mapped.
    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if no_mapping:
        logger.warning("Following symbols were not mapped: %s", no_mapping)
    else:
        logger.info("All symbols were mapped.")
    return model_weights
def test_get_symbols_id_mapping_without_map(self):
    """Id map built from an intersection map plus one explicit redirect.

    "b" is shared by both symbol sets; "a" (model side only) is pointed at
    "c" via update_existing_to_mappings before conversion to ids.
    """
    source_symbols = {"b", "c"}
    target_symbols = {"a", "b"}
    source_ids = SymbolIdDict.init_from_symbols(source_symbols)
    target_ids = SymbolIdDict.init_from_symbols(target_symbols)

    symbols_map = SymbolsMap.from_intersection(source_symbols, target_symbols)
    symbols_map.update_existing_to_mappings({"a": "c"})

    id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=target_ids,
        from_symbols=source_ids,
    )

    # NOTE(review): if init_from_symbols assigns ids in sorted order from 0,
    # these two pairs are symmetric, so the test would not detect a reversed
    # key/value direction in the id map — confirm and consider asymmetric data.
    expected_pairs = (("b", "b"), ("c", "a"))
    for from_symbol, to_symbol in expected_pairs:
        self.assertEqual(
            id_map[source_ids.get_id(from_symbol)],
            target_ids.get_id(to_symbol),
        )
    self.assertEqual(2, len(id_map))
def test_from_two_sets(self):
    """The second set supplies the keys; shared symbols map to themselves."""
    mapping = SymbolsMap.from_intersection({"a", "b"}, {"b", "c"})
    self.assertEqual(2, len(mapping))
    # "b" is in both sets -> identity; "c" is missing from the first -> "".
    for key, expected in (("b", "b"), ("c", "")):
        self.assertEqual(expected, mapping[key])