Exemplo n.º 1
0
  def test_update_existing_to_mappings_overwrites_empty_mapping(self):
    """An entry with an empty mapping takes the symbol given in the update."""
    symbols_map = SymbolsMap({"a": ""})
    update = {"a": "a"}

    symbols_map.update_existing_to_mappings(update)

    # Still a single entry, now mapped to "a".
    self.assertEqual(1, len(symbols_map))
    self.assertEqual("a", symbols_map["a"])
Exemplo n.º 2
0
  def test_update_existing_to_mappings_ignores_mapping_with_empty_symbol(self):
    """An update whose symbol is empty leaves the existing mapping untouched."""
    symbols_map = SymbolsMap({"a": "a"})
    update = {"a": ""}

    symbols_map.update_existing_to_mappings(update)

    # The entry keeps its previous, non-empty mapping.
    self.assertEqual(1, len(symbols_map))
    self.assertEqual("a", symbols_map["a"])
Exemplo n.º 3
0
  def test_create_or_update_map_with_template_map_and_existing_map(self):
    """Check the map and symbol list that create_or_update_map builds from a template map."""
    source_symbols = {"c", "d", "e"}
    target_symbols = {"a", "b", "c"}

    # NOTE(review): this map is constructed but never handed to
    # create_or_update_map (existing_map=None is passed below), although the
    # test name mentions an existing map -- confirm whether that is intended.
    existing_map = SymbolsMap({"a": "a", "b": "b", "c": "c"})

    template = SymbolsMap({"b": "d", "e": "f"})

    result_map, result_symbols = create_or_update_map(
      orig=source_symbols,
      dest=target_symbols,
      template_map=template,
      existing_map=None,
    )

    # "b" takes its mapping from the template, "c" maps to itself and "a"
    # stays unmapped.
    self.assertEqual(result_map["a"], "")
    self.assertEqual(result_map["b"], "d")
    self.assertEqual(result_map["c"], "c")
    self.assertEqual(3, len(result_map))
    self.assertEqual(["c", "d", "e"], result_symbols)
Exemplo n.º 4
0
 def test_save_load_symbols_map(self):
   """A saved SymbolsMap is loaded back with the same three entries."""
   symbols_map = SymbolsMap({
     "b": "a",
     "c": "b",
     "x": "y",
   })
   # A private temporary directory replaces the deprecated, race-prone
   # tempfile.mktemp(); it is also cleaned up when save() or load() raises,
   # so a failing test no longer leaves a stray file behind.
   with tempfile.TemporaryDirectory() as tmp_dir:
     path = os.path.join(tmp_dir, "symbols_map")
     symbols_map.save(path)
     res = SymbolsMap.load(path)
   self.assertEqual(3, len(res))
   self.assertEqual("a", res["b"])
   self.assertEqual("b", res["c"])
   self.assertEqual("y", res["x"])
Exemplo n.º 5
0
def create_or_update_inference_map_main(base_dir: str,
                                        prep_name: str,
                                        template_map: Optional[str] = None):
    """Create or update the inference map of a prepared dataset and save it.

    Args:
        base_dir: Base directory handed to ``get_prepared_dir``.
        prep_name: Name of the prepared dataset; its directory must exist.
        template_map: Optional path of a ``SymbolsMap`` file that is loaded
            and passed on as template.
    """
    logger = init_logger()
    add_console_out_to_logger(logger)
    logger.info("Creating/updating inference map...")

    prep_dir = get_prepared_dir(base_dir, prep_name)
    assert os.path.isdir(prep_dir)

    dest_symbols = get_all_symbols(prep_dir)

    loaded_template = (
        SymbolsMap.load(template_map) if template_map is not None else None
    )
    # An already saved inference map is passed on as ``existing_map``.
    previous_map = load_infer_map(prep_dir) if infer_map_exists(prep_dir) else None

    orig_symbols = load_prep_symbol_converter(prep_dir).get_all_symbols()
    infer_map, symbols = create_or_update_inference_map(
        orig=orig_symbols,
        dest=dest_symbols,
        existing_map=previous_map,
        template_map=loaded_template,
    )

    save_infer_map(prep_dir, infer_map)
    save_infer_symbols(prep_dir, symbols)
Exemplo n.º 6
0
def create_or_update_weights_map_main(base_dir: str,
                                      prep_name: str,
                                      weights_prep_name: str,
                                      template_map: Optional[str] = None):
    """Create or update the weights map between two prepared datasets and save it.

    Args:
        base_dir: Base directory handed to ``get_prepared_dir``.
        prep_name: Prepared dataset whose symbols are the mapping destination
            and in whose directory the results are saved.
        weights_prep_name: Prepared dataset whose symbols are the mapping
            origin.
        template_map: Optional path of a ``SymbolsMap`` file that is loaded
            and passed on as template.
    """
    prep_dir = get_prepared_dir(base_dir, prep_name)
    assert os.path.isdir(prep_dir)
    orig_prep_dir = get_prepared_dir(base_dir, weights_prep_name)
    assert os.path.isdir(orig_prep_dir)

    logger = init_logger()
    add_console_out_to_logger(logger)
    logger.info(f"Creating/updating weights map for {weights_prep_name}...")

    loaded_template = (
        SymbolsMap.load(template_map) if template_map is not None else None
    )
    # An already saved weights map is passed on as ``existing_map``.
    previous_map = (
        load_weights_map(prep_dir, weights_prep_name)
        if weights_map_exists(prep_dir, weights_prep_name)
        else None
    )

    orig_symbols = load_prep_symbol_converter(orig_prep_dir).get_all_symbols()
    dest_symbols = load_prep_symbol_converter(prep_dir).get_all_symbols()
    weights_map, symbols = create_or_update_weights_map(
        orig=orig_symbols,
        dest=dest_symbols,
        existing_map=previous_map,
        template_map=loaded_template,
    )

    save_weights_map(prep_dir, weights_prep_name, weights_map)
    save_weights_symbols(prep_dir, weights_prep_name, symbols)
Exemplo n.º 7
0
def get_mapped_symbol_weights(model_symbols: SymbolIdDict,
                              trained_weights: Tensor,
                              trained_symbols: SymbolIdDict,
                              custom_mapping: Optional[SymbolsMap],
                              hparams: HParams, logger: Logger) -> Tensor:
    """Build symbol weights for the model, filled from pretrained weights.

    Args:
        model_symbols: Symbols of the model that receives the weights.
        trained_weights: Pretrained weights; the size of the first dimension
            has to equal ``len(trained_symbols)``.
        trained_symbols: Symbols belonging to ``trained_weights``.
        custom_mapping: Optional mapping to use. If ``None``, a mapping is
            created with ``SymbolsMap.from_intersection``. NOTE: a given
            mapping is used directly (not copied); ``remove_unknown_symbols``
            and ``pop_batch`` are called on it.
        hparams: Supplies ``n_accents``, ``n_symbols`` and
            ``accents_use_own_symbols`` and is passed to
            ``get_symbol_weights``.
        logger: Receives the error, warning and info messages.

    Returns:
        The weights obtained from ``get_symbol_weights(hparams)`` after
        ``map_weights`` has been applied to them.

    Raises:
        Exception: If the first dimension of ``trained_weights`` does not
            match the number of trained symbols.
    """
    n_trained_weights = trained_weights.shape[0]
    n_trained_symbols = len(trained_symbols)
    if n_trained_weights != n_trained_symbols:
        msg = f"Weights mapping: symbol space from pretrained model ({n_trained_weights}) did not match amount of symbols ({n_trained_symbols})."
        # logger.exception() is only meant to be called from an exception
        # handler; outside of one it appends a meaningless "NoneType: None"
        # traceback. Log a plain error and put the message on the exception.
        logger.error(msg)
        raise Exception(msg)

    if custom_mapping is None:
        symbols_map = SymbolsMap.from_intersection(
            map_to=model_symbols.get_all_symbols(),
            map_from=trained_symbols.get_all_symbols(),
        )
    else:
        symbols_map = custom_mapping
        symbols_map.remove_unknown_symbols(
            known_to_symbol=model_symbols.get_all_symbols(),
            known_from_symbols=trained_symbols.get_all_symbols())

    # Remove all empty mappings
    symbols_wo_mapping = symbols_map.get_symbols_with_empty_mapping()
    symbols_map.pop_batch(symbols_wo_mapping)

    symbols_id_map = symbols_map.convert_to_symbols_ids_map(
        to_symbols=model_symbols,
        from_symbols=trained_symbols,
    )

    model_symbols_id_map = symbols_ids_map_to_model_symbols_ids_map(
        symbols_id_map,
        hparams.n_accents,
        n_symbols=hparams.n_symbols,
        accents_use_own_symbols=hparams.accents_use_own_symbols)

    model_weights = get_symbol_weights(hparams)

    map_weights(model_symbols_id_map=model_symbols_id_map,
                model_weights=model_weights,
                trained_weights=trained_weights,
                logger=logger)

    # Report the symbols whose mapping was empty together with the model
    # symbols that are not a key of the final map.
    not_existing_symbols = model_symbols.get_all_symbols() - symbols_map.keys()
    no_mapping = symbols_wo_mapping | not_existing_symbols
    if len(no_mapping) > 0:
        logger.warning(f"Following symbols were not mapped: {no_mapping}")
    else:
        logger.info("All symbols were mapped.")

    return model_weights
Exemplo n.º 8
0
  def test_get_symbols_id_mapping_without_map(self):
    """Convert an intersection map with one extra mapping into an id map."""
    source_symbols = {"b", "c"}
    target_symbols = {"a", "b"}
    source_ids = SymbolIdDict.init_from_symbols(source_symbols)
    target_ids = SymbolIdDict.init_from_symbols(target_symbols)
    symbols_map = SymbolsMap.from_intersection(source_symbols, target_symbols)
    symbols_map.update_existing_to_mappings({"a": "c"})

    id_map = symbols_map.convert_to_symbols_ids_map(
      to_symbols=target_ids,
      from_symbols=source_ids,
    )

    # Each pair names the symbol used for the lookup id and the symbol whose
    # id is expected as the value.
    expected_pairs = (("b", "b"), ("c", "a"))
    for lookup_symbol, expected_symbol in expected_pairs:
      self.assertEqual(id_map[source_ids.get_id(lookup_symbol)], target_ids.get_id(expected_symbol))
    self.assertEqual(2, len(id_map))
Exemplo n.º 9
0
def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap, ignore_arcs: bool) -> Tuple[SymbolIdDict, SentenceList]:
  """Apply a symbols map to all sentences and rebuild the sentence list.

  The symbols of each sentence are mapped, joined to text, split into
  sentences again and converted back to symbols. Every resulting sentence gets
  a new consecutive id starting at 1.

  Returns the result of ``update_symbols_and_text`` for the new sentences and
  their symbols.
  """
  symbols_per_sentence = []
  mapped_sentences = SentenceList()
  next_id = 0
  for source in sentences.items():
    source_symbols = text_symbols.get_symbols(source.serialized_symbols)
    source_accents = deserialize_list(source.serialized_accents)

    mapped_text = SymbolIdDict.symbols_to_text(
      symbols_map.apply_to_symbols(source_symbols)
    )
    # An empty text after mapping is not a problem.
    for part_text in split_sentences(mapped_text, source.lang):
      part_symbols = text_to_symbols(
        part_text,
        lang=source.lang,
        ignore_tones=False,
        ignore_arcs=ignore_arcs
      )

      # Every symbol of the new sentence gets the first accent of the source
      # sentence.
      part_accents = [source_accents[0]] * len(part_symbols) if len(source_accents) > 0 else []
      assert len(part_accents) == len(part_symbols)

      next_id += 1
      new_sentence = Sentence(
        sent_id=next_id,
        text=part_text,
        lang=source.lang,
        serialized_accents=serialize_list(part_accents),
        serialized_symbols=""
      )
      symbols_per_sentence.append(part_symbols)

      assert len(new_sentence.get_accent_ids()) == len(part_symbols)
      mapped_sentences.append(new_sentence)

  return update_symbols_and_text(mapped_sentences, symbols_per_sentence)
Exemplo n.º 10
0
def map_text(base_dir: str,
             prep_name: str,
             text_name: str,
             symbols_map_path: str,
             ignore_arcs: bool = True):
    """Apply the symbols map stored at ``symbols_map_path`` to an added text.

    Prints a hint and does nothing else when the text directory does not
    exist. Otherwise the mapped sentences are printed and saved together with
    their symbol converter, followed by the accent template and the check for
    unknown symbols.
    """
    prep_dir = get_prepared_dir(base_dir, prep_name, create=False)
    text_dir = get_text_dir(prep_dir, text_name, create=False)
    if not os.path.isdir(text_dir):
        print("Please add text first.")
        return

    symbol_ids, updated_sentences = sents_map(
        sentences=load_text_csv(text_dir),
        text_symbols=load_text_symbol_converter(text_dir),
        symbols_map=SymbolsMap.load(symbols_map_path),
        ignore_arcs=ignore_arcs)

    formatted = updated_sentences.get_formatted(
        symbol_id_dict=symbol_ids,
        accent_id_dict=load_prep_accents_ids(prep_dir))
    print("\n" + formatted)

    _save_text_csv(text_dir, updated_sentences)
    save_text_symbol_converter(text_dir, symbol_ids)
    _accent_template(base_dir, prep_name, text_name)
    _check_for_unknown_symbols(base_dir, prep_name, text_name)
Exemplo n.º 11
0
  def test_from_two_sets(self):
    """from_intersection gets one entry per symbol of the second set."""
    first_set = {"a", "b"}
    second_set = {"b", "c"}

    symbols_map = SymbolsMap.from_intersection(first_set, second_set)

    # "b" is in both sets and maps to itself; "c" has an empty mapping.
    self.assertEqual(2, len(symbols_map))
    self.assertEqual("b", symbols_map["b"])
    self.assertEqual("", symbols_map["c"])
Exemplo n.º 12
0
def try_load_symbols_map(symbols_map_path: str) -> Optional[SymbolsMap]:
    """Load the symbols map at ``symbols_map_path``; return ``None`` for a falsy path."""
    if not symbols_map_path:
        return None
    return SymbolsMap.load(symbols_map_path)
Exemplo n.º 13
0
def load_weights_map(prep_dir: str, orig_prep_name: str) -> SymbolsMap:
    """Load the weights map stored at ``<prep_dir>/<orig_prep_name>.json``."""
    file_name = f"{orig_prep_name}.json"
    return SymbolsMap.load(os.path.join(prep_dir, file_name))
Exemplo n.º 14
0
def save_weights_map(prep_dir: str, orig_prep_name: str,
                     weights_map: SymbolsMap):
    """Save ``weights_map`` to ``<prep_dir>/<orig_prep_name>.json``."""
    file_name = f"{orig_prep_name}.json"
    weights_map.save(os.path.join(prep_dir, file_name))
Exemplo n.º 15
0
def load_infer_map(prep_dir: str) -> SymbolsMap:
    """Load the inference map from the path ``get_infer_map_path`` returns for ``prep_dir``."""
    infer_map_path = get_infer_map_path(prep_dir)
    return SymbolsMap.load(infer_map_path)
Exemplo n.º 16
0
def save_infer_map(prep_dir: str, infer_map: SymbolsMap):
    """Save ``infer_map`` to the path ``get_infer_map_path`` returns for ``prep_dir``."""
    infer_map_path = get_infer_map_path(prep_dir)
    infer_map.save(infer_map_path)