Beispiel #1
0
def transfer_am(train_config):
    """
    Initialize an acoustic model from a pretrained model for fine-tuning.

    Loads the pretrained model's ``am_config.json``, resizes the phone output to the
    target language's phone inventory, loads the pretrained weights (setting up the
    phone layer through the inventory's unit mask), and writes the new config,
    phone list and weights into the directory of ``train_config.new_model``.

    :param train_config: training configuration; its ``pretrained_model``, ``lang``,
        ``device_id`` and ``new_model`` attributes are read here
    :return: the initialized ``AllosaurusTorchModel``
    :raises AssertionError: if the pretrained config's ``model`` is not
        ``'allosaurus'`` (note: asserts are stripped when Python runs with ``-O``)
    """

    pretrained_model_path = get_model_path(train_config.pretrained_model)

    # read the pretrained acoustic-model config; ``with`` closes the handle
    # deterministically instead of leaving it to the garbage collector
    with open(str(pretrained_model_path / 'am_config.json'), encoding='utf-8') as config_file:
        am_config = Namespace(**json.load(config_file))

    assert am_config.model == 'allosaurus', "This project only support allosaurus model"

    # load inventory
    inventory = Inventory(pretrained_model_path)

    # get unit_mask which maps the full phone inventory to the target phone inventory
    unit_mask = inventory.get_mask(train_config.lang, approximation=True)

    # reset phone_size so the model is built for the target phone inventory
    am_config.phone_size = len(unit_mask.target_unit)

    model = AllosaurusTorchModel(am_config)

    # load the pretrained model and setup the phone_layer with correct weights
    torch_load(model, str(pretrained_model_path / 'model.pt'), train_config.device_id, unit_mask)

    # destination directory of the new (fine-tuned) model
    model_path = get_model_path(train_config.new_model)

    # overwrite old am_config; closing the file via ``with`` guarantees the JSON
    # is flushed to disk before the function returns
    with open(str(model_path / 'am_config.json'), 'w', encoding='utf-8') as config_file:
        json.dump(vars(am_config), config_file, indent=4)

    # overwrite old phones
    write_unit(unit_mask.target_unit, model_path / 'phone.txt')

    # overwrite old model
    torch_save(model, model_path / 'model.pt')

    return model
Beispiel #2
0
    def __init__(self, model, train_config):
        """
        Set up the training state for ``model``.

        :param model: the model to train; it is also handed to the optimizer factory
        :param train_config: training configuration; ``device_id`` and ``new_model``
            are read directly, and the whole object is passed to the criterion,
            optimizer and reporter factories
        """
        # plain state: the model, its config and the target device
        self.model = model
        self.train_config = train_config
        self.device_id = train_config.device_id

        # bookkeeping for model selection / early stopping
        # NOTE(review): "per" is presumably phone error rate, so 100.0 would be
        # the worst possible starting value — confirm against the training loop
        self.best_per = 100.0
        self.num_no_improvement = 0

        # training helpers, built in their original order
        self.criterion = read_criterion(train_config)              # only ctc currently
        self.optimizer = read_optimizer(self.model, train_config)  # only sgd currently
        self.reporter = Reporter(train_config)                     # writes logs

        # directory of the new model (resolved from train_config.new_model)
        self.model_path = get_model_path(train_config.new_model)
Beispiel #3
0
from pathlib import Path
from allosaurus.lm.inventory import Inventory
from allosaurus.model import get_model_path
import argparse

if __name__ == '__main__':

    parser = argparse.ArgumentParser('Update language inventory')
    parser.add_argument('-l', '--lang',  type=str, required=True, help='specify which language inventory to update.')
    parser.add_argument('-m', '--model', type=str, default='latest', help='specify which model inventory')
    parser.add_argument('-i', '--input', type=str, required=True, help='your new inventory file')

    args = parser.parse_args()

    model_path = get_model_path(args.model)

    inventory = Inventory(model_path)

    lang = args.lang

    # verify lang is not ipa as it is an alias to the entire inventory
    assert args.lang != 'ipa', "ipa is not a proper lang to update. use list_lang to find a proper language"

    assert lang.lower() in inventory.lang_ids or lang.lower() in inventory.glotto_ids, f'language {args.lang} is not supported. Please verify it is in the language list'

    new_unit_file = Path(args.input)

    # check existence of the file
    assert new_unit_file.exists(), args.input+' does not exist'

    # update this new unit