예제 #1
0
def get_mapped_symbol_weights(model_symbols: SymbolIdDict,
                              trained_weights: Tensor,
                              trained_symbols: SymbolIdDict,
                              custom_mapping: Optional[SymbolsMap],
                              hparams: HParams, logger: Logger) -> Tensor:
    """Create model symbol weights and fill them from pretrained weights.

    A fresh weight tensor is created with ``get_symbol_weights(hparams)`` and
    handed to ``map_weights`` together with ``trained_weights`` and a symbol-id
    mapping derived from ``trained_symbols`` -> ``model_symbols``.

    Args:
        model_symbols: Symbol/id table of the model being initialized.
        trained_weights: Pretrained symbol weights; ``shape[0]`` must equal
            ``len(trained_symbols)``.
        trained_symbols: Symbol/id table the pretrained weights belong to.
        custom_mapping: Explicit symbol mapping to use. If ``None``, a map is
            built with ``SymbolsMap.from_intersection`` (presumably mapping the
            symbols both tables share). NOTE: when given, this object is
            modified in place (unknown symbols and empty mappings are removed).
        hparams: Hyper-parameters; ``n_accents``, ``n_symbols`` and
            ``accents_use_own_symbols`` are read here.
        logger: Receives the mismatch error and the mapping summary.

    Returns:
        The weight tensor created by ``get_symbol_weights(hparams)`` after it
        has been passed through ``map_weights`` (presumably filled in place —
        confirm against ``map_weights``).

    Raises:
        ValueError: If ``trained_weights.shape[0] != len(trained_symbols)``.
    """
    n_trained_rows = trained_weights.shape[0]
    n_trained_symbols = len(trained_symbols)
    if n_trained_rows != n_trained_symbols:
        msg = f"Weights mapping: symbol space from pretrained model ({n_trained_rows}) did not match amount of symbols ({n_trained_symbols})."
        # logger.error rather than logger.exception: no exception is being
        # handled here, so logger.exception would append "NoneType: None".
        logger.error(msg)
        # ValueError subclasses Exception, so callers catching Exception still work.
        raise ValueError(msg)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        # NOTE(review): the caller's custom_mapping object itself is mutated below.
        symbols_map = custom_mapping
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols())

    # Remove all empty mappings; remember them for the summary at the end.
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )

    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols)

    model_weights = get_symbol_weights(hparams)

    map_weights(model_symbols_id_map=model_symbols_id_map,
                model_weights=model_weights,
                trained_weights=trained_weights,
                logger=logger)

    # Model symbols that are not keys of the cleaned-up map, plus those whose
    # mapping was empty, did not receive pretrained weights.
    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if no_mapping:
        logger.warning(f"Following symbols were not mapped: {no_mapping}")
    else:
        logger.info("All symbols were mapped.")

    return model_weights
예제 #2
0
def update_symbols(data: MergedDataset, symbols: SymbolIdDict) -> SymbolIdDict:
    """Rebuild the symbol table from the symbols actually used by ``data``.

    A new table (including ``DEFAULT_PADDING_SYMBOL``) is built from every
    symbol referenced by the entries of ``data``. If its symbol set differs
    from that of ``symbols``, each entry's ``serialized_symbol_ids`` is
    rewritten in place against the new table, and the new table is returned.
    Otherwise the entries are left untouched and ``symbols`` itself is
    returned, because the entries are still serialized against it.

    Args:
        data: Dataset whose entries carry ``serialized_symbol_ids``; may be
            modified in place.
        symbols: Table the entries are currently serialized against.

    Returns:
        The table that the entries of ``data`` are serialized against after
        this call.
    """
    used_symbols: Set[str] = {
        symbol
        for entry in data.items()
        for symbol in symbols.get_symbols(entry.serialized_symbol_ids)
    }
    new_symbol_ids = SymbolIdDict.init_from_symbols_with_pad(
        used_symbols, pad_symbol=DEFAULT_PADDING_SYMBOL)

    if new_symbol_ids.get_all_symbols() == symbols.get_all_symbols():
        # The entries are not re-serialized on this path, so only ``symbols``
        # is guaranteed to match their ids; the fresh table could in principle
        # assign different ids to the same symbols.
        return symbols

    for entry in data.items():
        original_symbols = symbols.get_symbols(entry.serialized_symbol_ids)
        entry.serialized_symbol_ids = new_symbol_ids.get_serialized_ids(
            original_symbols)
    return new_symbol_ids
예제 #3
0
def get_shard_size(symbols: SymbolIdDict) -> int:
    """Return the number of symbols in ``symbols``, excluding the padding symbol.

    ``DEFAULT_PADDING_SYMBOL`` is subtracted only when it is actually present
    in ``symbols.get_all_symbols()``.
    """
    known_symbols = symbols.get_all_symbols()
    total = len(known_symbols)
    return total - 1 if DEFAULT_PADDING_SYMBOL in known_symbols else total