コード例 #1
0
def get_mapped_symbol_weights(model_symbols: SymbolIdDict,
                              trained_weights: Tensor,
                              trained_symbols: SymbolIdDict,
                              custom_mapping: Optional[SymbolsMap],
                              hparams: HParams, logger: Logger) -> Tensor:
    """Create the model's symbol weights and fill them from pretrained weights.

    A symbol mapping between the pretrained symbols and the model symbols is
    built (or taken from ``custom_mapping``), converted to an id mapping and
    handed to ``map_weights`` together with freshly created model weights.

    Args:
        model_symbols: Symbols of the model that receives the weights.
        trained_weights: Pretrained weights; its first dimension must equal
            ``len(trained_symbols)``.
        trained_symbols: Symbols belonging to ``trained_weights``.
        custom_mapping: Optional explicit mapping. If ``None``, the mapping is
            derived via ``SymbolsMap.from_intersection``. NOTE: when given,
            this object itself is used and pruned by the calls below (their
            return values are ignored), so the caller's instance is
            presumably modified in place.
        hparams: Provides ``n_accents``, ``n_symbols`` and
            ``accents_use_own_symbols`` and is passed to ``get_symbol_weights``.
        logger: Receives the error, warning and info messages.

    Returns:
        The weights created by ``get_symbol_weights(hparams)`` after
        ``map_weights`` was applied to them.

    Raises:
        ValueError: If the first dimension of ``trained_weights`` differs from
            the number of trained symbols.
    """
    n_trained_weights = trained_weights.shape[0]
    n_trained_symbols = len(trained_symbols)
    if n_trained_weights != n_trained_symbols:
        message = (
            f"Weights mapping: symbol space from pretrained model ({n_trained_weights}) "
            f"did not match amount of symbols ({n_trained_symbols})."
        )
        # No exception is being handled at this point, so logger.exception
        # would only append an empty "NoneType: None" traceback.
        logger.error(message)
        raise ValueError(message)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        symbols_map = custom_mapping
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols())

    # Drop entries with an empty mapping, but remember them for the report below.
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )

    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols)

    model_weights = get_symbol_weights(hparams)

    # The return value is not used; map_weights presumably writes into
    # model_weights in place.
    map_weights(model_symbols_id_map=model_symbols_id_map,
                model_weights=model_weights,
                trained_weights=trained_weights,
                logger=logger)

    # Model symbols that never appear as a key of the (pruned) mapping.
    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if no_mapping:
        logger.warning(f"Following symbols were not mapped: {no_mapping}")
    else:
        logger.info("All symbols were mapped.")

    return model_weights
コード例 #2
0
  def test_get_symbols_id_mapping_without_map(self):
    # The "from" and "to" symbol sets have only "b" in common.
    source_symbols = {"b", "c"}
    target_symbols = {"a", "b"}
    source_ids = SymbolIdDict.init_from_symbols(source_symbols)
    target_ids = SymbolIdDict.init_from_symbols(target_symbols)

    symbol_map = SymbolsMap.from_intersection(source_symbols, target_symbols)
    # Override the entry for "a" with "c".
    symbol_map.update_existing_to_mappings({"a": "c"})

    id_map = symbol_map.convert_to_symbols_ids_map(
      to_symbols=target_ids,
      from_symbols=source_ids,
    )

    # (symbol looked up in the "from" dict, symbol looked up in the "to" dict)
    expected_pairs = (("b", "b"), ("c", "a"))
    for source_symbol, target_symbol in expected_pairs:
      self.assertEqual(
        id_map[source_ids.get_id(source_symbol)],
        target_ids.get_id(target_symbol),
      )
    self.assertEqual(2, len(id_map))
コード例 #3
0
  def test_from_two_sets(self):
    # The two input sets overlap only in "b".
    symbol_map = SymbolsMap.from_intersection({"a", "b"}, {"b", "c"})

    self.assertEqual(2, len(symbol_map))
    # "b" is expected to map to itself, "c" to the empty string.
    expected_entries = (("b", "b"), ("c", ""))
    for key, expected_value in expected_entries:
      self.assertEqual(expected_value, symbol_map[key])