예제 #1
0
def main():
    testset = TextDataset(args.testset)
    test_loader = DataLoader(dataset=testset,
                             batch_size=args.test_batch,
                             drop_last=False,
                             shuffle=False,
                             collate_fn=synth_collate_fn,
                             pin_memory=True)

    t2m = Text2Mel().to(DEVICE)
    ssrn = SSRN().to(DEVICE)

    mname = type(t2m).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname, '*k.pth.tar')))
    state = torch.load(ckpt[-1])
    t2m.load_state_dict(state['model'])
    args.global_step = state['global_step']

    mname = type(ssrn).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname, '*k.pth.tar')))
    state = torch.load(ckpt[-1])
    ssrn.load_state_dict(state['model'])

    print('All of models are loaded.')

    t2m.eval()
    ssrn.eval()

    if not os.path.exists(os.path.join(args.sampledir, 'A')):
        os.makedirs(os.path.join(args.sampledir, 'A'))
    synthesize(t2m, ssrn, test_loader, args.test_batch)
예제 #2
0
def main(network):
    if network == 'text2mel':
        model = Text2Mel().to(DEVICE)
    elif network == 'ssrn':
        model = SSRN().to(DEVICE)
    else:
        print('Wrong network. {text2mel, ssrn}')
        return
    print('Model {} is working...'.format(type(model).__name__))
    print('{} threads are used...'.format(torch.get_num_threads()))
    ckpt_dir = os.path.join(args.logdir, type(model).__name__)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[50000, 150000, 300000], gamma=0.5) #

    if not os.path.exists(ckpt_dir):
        os.makedirs(os.path.join(ckpt_dir, 'A', 'train'))
    else:
        print('Already exists. Retrain the model.')
        ckpt = sorted(glob.glob(os.path.join(ckpt_dir, '*k.pth.tar')))[-1]
        state = torch.load(ckpt)
        model.load_state_dict(state['model'])
        args.global_step = state['global_step']
        optimizer.load_state_dict(state['optimizer'])
        # scheduler.load_state_dict(state['scheduler'])

    # model = torch.nn.DataParallel(model, device_ids=list(range(args.no_gpu))).to(DEVICE)
    if type(model).__name__ == 'Text2Mel':
        if args.ga_mode:
            cfn_train, cfn_eval = t2m_ga_collate_fn, t2m_collate_fn
        else:
            cfn_train, cfn_eval = t2m_collate_fn, t2m_collate_fn
    else:
        cfn_train, cfn_eval = collate_fn, collate_fn

    dataset = SpeechDataset(args.data_path, args.meta_train, type(model).__name__, mem_mode=args.mem_mode, ga_mode=args.ga_mode)
    validset = SpeechDataset(args.data_path, args.meta_eval, type(model).__name__, mem_mode=args.mem_mode)
    data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size,
                             shuffle=True, collate_fn=cfn_train,
                             drop_last=True, pin_memory=True)
    valid_loader = DataLoader(dataset=validset, batch_size=args.test_batch,
                              shuffle=False, collate_fn=cfn_eval, pin_memory=True)
    
    writer = SummaryWriter(ckpt_dir)
    train(model, data_loader, valid_loader, optimizer, scheduler,
          batch_size=args.batch_size, ckpt_dir=ckpt_dir, writer=writer)
    return None
예제 #3
0
                                           mode='valid')
else:
    if args.dataset == 'ljspeech':
        from datasets.lj_speech import vocab, LJSpeech as SpeechDataset
    elif args.dataset == 'mbspeech':
        from datasets.mb_speech import vocab, MBSpeech as SpeechDataset
    train_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(
        ['texts', 'mels', 'mel_gates']),
                                           batch_size=64,
                                           mode='train')
    valid_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(
        ['texts', 'mels', 'mel_gates']),
                                           batch_size=64,
                                           mode='valid')

text2mel = Text2Mel(vocab).cuda()
"""
if args.warmstart:
    old_lr = hp.text2mel_lr 
    hp.text2mel_lr = hp.text2mel_lr / 10.0
    print("Reducing learning rate from %.9f to %.9f because of warmstart" % (old_lr, hp.text2mel_lr))
"""

optimizer = torch.optim.Adam(text2mel.parameters(), lr=hp.text2mel_lr)

start_timestamp = int(time.time() * 1000)
start_epoch = 0
global_step = 0

logger = Logger(args.dataset, 'text2mel')
예제 #4
0
import warnings
warnings.filterwarnings("ignore")  
from playsound import playsound
import numpy as np
import torch
from num2words import num2words 
from hparams import HParams as hp
from audio import save_to_wav
from models import SSRN,Text2Mel
from lj_speech import vocab, idx2char, get_test_data

torch.set_grad_enabled(False)
text2mel = Text2Mel(vocab)
text2mel.load_state_dict(torch.load("ljspeech-text2mel.pth").state_dict())
text2mel = text2mel.eval()
ssrn = SSRN()
ssrn.load_state_dict(torch.load("ljspeech-ssrn.pth").state_dict())
ssrn = ssrn.eval()


def say(sentence):
    new_sentence=" " .join([num2words(w) if w.isdigit()  else w for w in sentence.split()])
    normalized_sentence = "".join([c if c.lower() in vocab else '' for c in new_sentence])
    print(normalized_sentence)
    sentences = [normalized_sentence]
    max_N = len(normalized_sentence)
    L = torch.from_numpy(get_test_data(sentences, max_N))
    zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32))
    Y = zeros
    A = None
예제 #5
0
    SENTENCES = [
        "Нийслэлийн прокурорын газраас төрийн өндөр албан тушаалтнуудад холбогдох зарим эрүүгийн хэргүүдийг шүүхэд шилжүүлэв.",
        "Мөнх тэнгэрийн хүчин дор Монгол Улс цэцэглэн хөгжих болтугай.",
        "Унасан хүлгээ түрүү магнай, аман хүзүүнд уралдуулж, айрагдуулсан унаач хүүхдүүдэд бэлэг гардууллаа.",
        "Албан ёсоор хэлэхэд “Монгол Улсын хэрэг эрхлэх газрын гэгээнтэн” гэж нэрлээд байгаа зүйл огт байхгүй.",
        "Сайн чанарын бохирын хоолой зарна.",
        "Хараа тэглэх мэс заслын дараа хараа дахин муудах магадлал бага.",
        "Ер нь бол хараа тэглэх мэс заслыг гоо сайхны мэс засалтай адилхан гэж зүйрлэж болно.",
        "Хашлага даван, зүлэг гэмтээсэн жолоочийн эрхийг хоёр жилээр хасжээ.",
        "Монгол хүн бидний сэтгэлийг сорсон орон. Энэ бол миний төрсөн нутаг. Монголын сайхан орон.",
        "Постройка крейсера затягивалась из-за проектных неувязок, необходимости."
    ]

torch.set_grad_enabled(False)

text2mel = Text2Mel(vocab).eval()
last_checkpoint_file_name = get_last_checkpoint_file_name(
    os.path.join(hp.logdir, '%s-text2mel' % args.dataset))
# last_checkpoint_file_name = 'logdir/%s-text2mel/step-020K.pth' % args.dataset
if last_checkpoint_file_name:
    print("loading text2mel checkpoint '%s'..." % last_checkpoint_file_name)
    load_checkpoint(last_checkpoint_file_name, text2mel, None)
else:
    print("text2mel not exits")
    sys.exit(1)

ssrn = SSRN().eval()
last_checkpoint_file_name = get_last_checkpoint_file_name(
    os.path.join(hp.logdir, '%s-ssrn' % args.dataset))
# last_checkpoint_file_name = 'logdir/%s-ssrn/step-005K.pth' % args.dataset
if last_checkpoint_file_name:
load_checkpoint('trained/ssrn/lj/step-140K.pth', ssrn, None)
# last_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-ssrn' % args.dataset))
# last_checkpoint_file_name = 'logdir/%s-ssrn/step-005K.pth' % args.dataset
# if last_checkpoint_file_name:
#     print("loading ssrn checkpoint '%s'..." % last_checkpoint_file_name)
#     load_checkpoint(last_checkpoint_file_name, ssrn, None)
# else:
#     print("ssrn not exits")
#     sys.exit(1)
if not os.path.isdir(f'samples'):
    os.mkdir(f'samples')

for t2m in t2m_list:
    filename = os.path.splitext(os.path.basename(t2m))[0]
    folder = os.path.split(os.path.split(t2m)[0])[-1]
    text2mel = Text2Mel(vocab).to(device).eval()
    print("loading text2mel...")
    load_checkpoint(t2m, text2mel, None)
    # text2mel = Text2Mel(vocab)
    # text2mel.load_state_dict(torch.load(t2m).state_dict())
    # text2mel = text2mel.eval()
    for sentence in SENTENCES:
        with torch.no_grad():
            L = torch.from_numpy(get_test_data(sentence)).to(device)
            zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32)).to(device)
            Y = zeros
            # A = None

            while True:
                _, Y_t, A = text2mel(L, Y, monotonic_attention=True)
                Y = torch.cat((zeros, Y_t), -1)