def transfer_am(train_config):
    """
    initialize the acoustic model with a pretrained model for fine-tuning

    The pretrained model's config and weights are loaded, its phone layer is
    restricted to the target language's phone inventory, and the resulting
    config, phone list and weights are written into the new model's directory.

    :param train_config: training configuration; this function reads its
        pretrained_model, lang, device_id and new_model attributes
    :return: the AllosaurusTorchModel built from the pretrained weights
    :raises AssertionError: if the pretrained config's model is not 'allosaurus'
    """

    pretrained_model_path = get_model_path(train_config.pretrained_model)

    # read the pretrained config inside a context manager so the file handle
    # is closed deterministically instead of being left to the garbage collector
    with open(str(pretrained_model_path / 'am_config.json')) as config_file:
        am_config = Namespace(**json.load(config_file))

    assert am_config.model == 'allosaurus', "This project only support allosaurus model"

    # load inventory
    inventory = Inventory(pretrained_model_path)

    # get unit_mask which maps the full phone inventory to the target phone inventory
    unit_mask = inventory.get_mask(train_config.lang, approximation=True)

    # reset the new phone_size
    am_config.phone_size = len(unit_mask.target_unit)

    model = AllosaurusTorchModel(am_config)

    # load the pretrained model and setup the phone_layer with correct weights
    torch_load(model, str(pretrained_model_path / 'model.pt'), train_config.device_id, unit_mask)

    # update new model
    new_model = train_config.new_model

    # get its path
    model_path = get_model_path(new_model)

    # overwrite old am_config; the with-block guarantees the JSON is flushed
    # and the handle closed before the function continues
    new_am_config_json = vars(am_config)
    with open(str(model_path / 'am_config.json'), 'w') as config_file:
        json.dump(new_am_config_json, config_file, indent=4)

    # overwrite old phones
    write_unit(unit_mask.target_unit, model_path / 'phone.txt')

    # overwrite old model
    torch_save(model, model_path / 'model.pt')

    return model
def __init__(self, model, train_config):
    """
    Set up the state needed to train `model`.

    :param model: the model to be trained; stored as-is and handed to the optimizer
    :param train_config: training configuration; its device_id and new_model
        attributes are read here, and the whole object is passed on to the
        criterion, optimizer and reporter factories
    """

    self.model = model
    self.train_config = train_config
    self.device_id = train_config.device_id

    # bookkeeping for model selection and early stopping:
    # best per seen so far, and how many times there was no improvement
    self.best_per = 100.0
    self.num_no_improvement = 0

    # training criterion (only ctc currently)
    self.criterion = read_criterion(train_config)

    # optimizer over the model (only sgd currently)
    self.optimizer = read_optimizer(self.model, train_config)

    # reporter used to write logs
    self.reporter = Reporter(train_config)

    # path of the new model named in the config
    self.model_path = get_model_path(train_config.new_model)
from pathlib import Path from allosaurus.lm.inventory import Inventory from allosaurus.model import get_model_path import argparse if __name__ == '__main__': parser = argparse.ArgumentParser('Update language inventory') parser.add_argument('-l', '--lang', type=str, required=True, help='specify which language inventory to update.') parser.add_argument('-m', '--model', type=str, default='latest', help='specify which model inventory') parser.add_argument('-i', '--input', type=str, required=True, help='your new inventory file') args = parser.parse_args() model_path = get_model_path(args.model) inventory = Inventory(model_path) lang = args.lang # verify lang is not ipa as it is an alias to the entire inventory assert args.lang != 'ipa', "ipa is not a proper lang to update. use list_lang to find a proper language" assert lang.lower() in inventory.lang_ids or lang.lower() in inventory.glotto_ids, f'language {args.lang} is not supported. Please verify it is in the language list' new_unit_file = Path(args.input) # check existence of the file assert new_unit_file.exists(), args.input+' does not exist' # update this new unit