Example #1
def setup():
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    global model
    global device

    char2index, index2char = label_loader.load_label_json(
        "../data/kor_syllable_zeroth.json")
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    print(f"device: {device}")

    input_size = 161  # spectrogram feature dimension (see the note below)
    enc = EncoderRNN(input_size,
                     512,
                     n_layers=3,
                     dropout_p=0.3,
                     bidirectional=True,
                     rnn_cell='LSTM',
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     128,
                     512,
                     512,
                     SOS_token,
                     EOS_token,
                     n_layers=2,
                     rnn_cell='LSTM',
                     dropout_p=0.3,
                     bidirectional_encoder=True)

    model = Seq2Seq(enc, dec).to(device)

    model_path = "../models/zeroth_korean_trimmed/LSTM_512x3_512x2_zeroth_korean_trimmed/final.pth"
    print("Loading checkpoint model %s" % model_path)
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model'])
    print('Model loaded')
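
Note: the hard-coded `input_size` of 161 matches the spectrogram feature dimension that the later examples derive from their audio config; a quick sanity check, assuming the same 16 kHz sample rate and 20 ms window used there:

import math

sample_rate = 16000   # Hz, as in the other examples
window_size = 0.02    # 20 ms STFT window
# one-sided FFT bin count: floor(n_fft / 2) + 1, with n_fft = sample_rate * window_size
assert int(math.floor((sample_rate * window_size) / 2) + 1) == 161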
Example #2
def main():
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    parser = argparse.ArgumentParser(description='LAS')
    parser.add_argument('--model-name', type=str, default='LAS')
    # Dataset
    parser.add_argument('--train-file',
                        type=str,
                        help='Data list for the training dataset',
                        default='data/ClovaCall/train_ClovaCall.json')
    parser.add_argument('--test-file-list',
                        nargs='*',
                        help='Data list for the test dataset',
                        default=['data/ClovaCall/test_ClovaCall.json'])
    parser.add_argument('--labels-path',
                        default='data/kor_syllable.json',
                        help='Path to the JSON file of Korean syllable labels')
    parser.add_argument('--dataset-path',
                        default='data/ClovaCall/clean',
                        help='Target dataset path')
    # Hyperparameters
    parser.add_argument('--rnn-type',
                        default='lstm',
                        help='Type of the RNN. rnn|gru|lstm are supported')
    parser.add_argument('--encoder_layers',
                        type=int,
                        default=3,
                        help='Number of encoder layers (default: 3)')
    parser.add_argument('--encoder_size',
                        type=int,
                        default=512,
                        help='Encoder hidden size (default: 512)')
    parser.add_argument('--decoder_layers',
                        type=int,
                        default=2,
                        help='Number of decoder layers (default: 2)')
    parser.add_argument('--decoder_size',
                        type=int,
                        default=512,
                        help='Decoder hidden size (default: 512)')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.3,
                        help='Dropout rate in training (default: 0.3)')
    parser.add_argument(
        '--no-bidirectional',
        dest='bidirectional',
        action='store_false',
        default=True,
        help='Turn off bi-directional encoder RNNs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size in training (default: 32)')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of workers in dataset loader (default: 4)')
    parser.add_argument('--num_gpu',
                        type=int,
                        default=1,
                        help='Number of GPUs (default: 1)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='Number of max epochs in training (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=3e-4,
                        help='Learning rate (default: 3e-4)')
    parser.add_argument('--learning-anneal',
                        default=1.1,
                        type=float,
                        help='Factor to divide the learning rate by after each epoch')
    parser.add_argument('--teacher_forcing',
                        type=float,
                        default=1.0,
                        help='Teacher forcing ratio in decoder (default: 1.0)')
    parser.add_argument('--max_len',
                        type=int,
                        default=80,
                        help='Maximum number of characters per sentence (default: 80)')
    parser.add_argument('--max-norm',
                        default=400,
                        type=int,
                        help='Norm cutoff to prevent explosion of gradients')
    # Audio Config
    parser.add_argument('--sample-rate',
                        default=16000,
                        type=int,
                        help='Sampling Rate')
    parser.add_argument('--window-size',
                        default=.02,
                        type=float,
                        help='Window size for spectrogram')
    parser.add_argument('--window-stride',
                        default=.01,
                        type=float,
                        help='Window stride for spectrogram')
    # System
    parser.add_argument('--save-folder',
                        default='models',
                        help='Location to save epoch models')
    parser.add_argument('--model-path',
                        default='models/las_final.pth',
                        help='Location to save best validation model')
    parser.add_argument(
        '--log-path',
        default='log/',
        help='Path for prediction logs on the validation and test datasets')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='Enable CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=123456,
                        help='random seed (default: 123456)')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='Train or Test')
    parser.add_argument('--load-model',
                        action='store_true',
                        default=False,
                        help='Load model')
    parser.add_argument('--finetune',
                        dest='finetune',
                        action='store_true',
                        default=False,
                        help='Fine-tune the model after loading it')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    char2index, index2char = label_loader.load_label_json(args.labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    device = torch.device('cuda' if args.cuda else 'cpu')

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride)

    # Batch Size
    batch_size = args.batch_size * args.num_gpu

    print(">> Train dataset : ", args.train_file)
    trainData_list = []
    with open(args.train_file, 'r', encoding='utf-8') as f:
        trainData_list = json.load(f)

    if args.num_gpu != 1:
        last_batch = len(trainData_list) % batch_size
        if last_batch != 0 and last_batch < args.num_gpu:
            trainData_list = trainData_list[:-last_batch]

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       dataset_path=args.dataset_path,
                                       data_list=trainData_list,
                                       char2index=char2index,
                                       sos_id=SOS_token,
                                       eos_id=EOS_token,
                                       normalize=True)

    train_sampler = BucketingSampler(train_dataset, batch_size=batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)

    print(">> Test dataset : ", args.test_file_list)
    testLoader_dict = {}
    for test_file in args.test_file_list:
        testData_list = []
        with open(test_file, 'r', encoding='utf-8') as f:
            testData_list = json.load(f)

        test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                          dataset_path=args.dataset_path,
                                          data_list=testData_list,
                                          char2index=char2index,
                                          sos_id=SOS_token,
                                          eos_id=EOS_token,
                                          normalize=True)
        testLoader_dict[test_file] = AudioDataLoader(
            test_dataset, batch_size=1, num_workers=args.num_workers)

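    # Spectrogram feature dimension: one-sided STFT bin count = floor(n_fft / 2) + 1,
    # where n_fft = sample_rate * window_size (16000 * 0.02 = 320 samples -> 161 bins).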
    input_size = int(math.floor((args.sample_rate * args.window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     args.encoder_size,
                     n_layers=args.encoder_layers,
                     dropout_p=args.dropout,
                     bidirectional=args.bidirectional,
                     rnn_cell=args.rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     args.max_len,
                     args.decoder_size,
                     args.encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=args.decoder_layers,
                     rnn_cell=args.rnn_type,
                     dropout_p=args.dropout,
                     bidirectional_encoder=args.bidirectional)

    model = Seq2Seq(enc, dec)

    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)

    optim_state = None
    if args.load_model:  # Starting from previous model
        print("Loading checkpoint model %s" % args.model_path)
        state = torch.load(args.model_path)
        model.load_state_dict(state['model'])
        print('Model loaded')

        if not args.finetune:  # Just load model
            optim_state = state['optimizer']

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if optim_state is not None:
        optimizer.load_state_dict(optim_state)

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print(model)
    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    train_model = nn.DataParallel(model)

    if args.mode != "train":
        for test_file in args.test_file_list:
            test_loader = testLoader_dict[test_file]
            test_loss, test_cer, transcripts_list = evaluate(model,
                                                             test_loader,
                                                             criterion,
                                                             device,
                                                             save_output=True)

            for idx, line in enumerate(transcripts_list):
                # print(line)
                hyp, ref = line.split('\t')
                print("({:3d}/{:3d}) [REF]: {}".format(idx + 1,
                                                       len(transcripts_list),
                                                       ref))
                print("({:3d}/{:3d}) [HYP]: {}".format(idx + 1,
                                                       len(transcripts_list),
                                                       hyp))
                print()

            print("Test {} CER : {}".format(test_file, test_cer))
    else:
        best_cer = 1e10
        begin_epoch = 0

        # start_time = time.time()
        start_time = datetime.datetime.now()

        for epoch in range(begin_epoch, args.epochs):
            train_loss, train_cer = train(train_model, train_loader, criterion,
                                          optimizer, device, epoch,
                                          train_sampler, args.max_norm,
                                          args.teacher_forcing)

            # end_time = time.time()
            # elapsed_time = end_time - start_time
            elapsed_time = datetime.datetime.now() - start_time

            train_log = 'Train({name}) Summary Epoch: [{0}]\tAverage Loss {loss:.3f}\tAverage CER {cer:.3f}\tTime {time:}'.format(
                epoch + 1,
                name='train',
                loss=train_loss,
                cer=train_cer,
                time=elapsed_time)
            print(train_log)

            cer_list = []
            for test_file in args.test_file_list:
                test_loader = testLoader_dict[test_file]
                test_loss, test_cer, _ = evaluate(model,
                                                  test_loader,
                                                  criterion,
                                                  device,
                                                  save_output=False)
                test_log = 'Test({name}) Summary Epoch: [{0}]\tAverage Loss {loss:.3f}\tAverage CER {cer:.3f}\t'.format(
                    epoch + 1, name=test_file, loss=test_loss, cer=test_cer)
                print(test_log)

                cer_list.append(test_cer)

            if best_cer > cer_list[0]:
                print("Found better validated model, saving to %s" %
                      args.model_path)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, args.model_path)
                best_cer = cer_list[0]

            print("Shuffling batches...")
            train_sampler.shuffle(epoch)

            for g in optimizer.param_groups:
                g['lr'] = g['lr'] / args.learning_anneal
            print('Learning rate annealed to: {lr:.6f}'.format(
                lr=optimizer.param_groups[0]['lr']))
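
Note: because the learning rate is divided by `--learning-anneal` after every epoch, the schedule decays geometrically; a small sketch with the defaults above (lr=3e-4, anneal=1.1):

lr0, anneal = 3e-4, 1.1
schedule = [lr0 / anneal ** epoch for epoch in range(1, 6)]
# -> [2.73e-04, 2.48e-04, 2.25e-04, 2.05e-04, 1.86e-04] (rounded)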
Example #3
def main():
    os.chdir(os.path.dirname(__file__))
    os.chdir('../')

    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    model_name = 'LAS'
    # Dataset
    test_file_list = ['data/Youtube_test/youtube_test.json']
    labels_path = 'data/kor_syllable.json'
    dataset_path = 'data/Youtube_test'

    # Hyperparameters
    rnn_type = 'lstm'
    encoder_layers = 3
    encoder_size = 512
    decoder_layers = 2
    decoder_size = 512
    dropout = 0.3
    bidirectional = True
    num_workers = 4
    max_len = 80

    # Audio Config
    sample_rate = 16000
    window_size = .02
    window_stride = .01

    # System
    save_folder = 'models'
    model_path = 'models/AIHub_train/LSTM_512x3_512x2_AIHub_train/final.pth'
    cuda = True  # assumes a CUDA-capable GPU is available
    seed = 123456

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    char2index, index2char = label_loader.load_label_json(labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    device = torch.device('cuda' if cuda else 'cpu')

    audio_conf = dict(sample_rate=sample_rate,
                      window_size=window_size,
                      window_stride=window_stride)

    print(">> Test dataset : ", test_file_list)
    testLoader_dict = {}

    for test_file in test_file_list:
        testData_list = []
        with open(test_file, 'r', encoding='utf-8') as f:
            testData_list = json.load(f)

        test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                          dataset_path=dataset_path,
                                          data_list=testData_list,
                                          char2index=char2index,
                                          sos_id=SOS_token,
                                          eos_id=EOS_token,
                                          normalize=True)
        testLoader_dict[test_file] = AudioDataLoader(test_dataset,
                                                     batch_size=1,
                                                     num_workers=num_workers)

    input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     encoder_size,
                     n_layers=encoder_layers,
                     dropout_p=dropout,
                     bidirectional=bidirectional,
                     rnn_cell=rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     max_len,
                     decoder_size,
                     encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=decoder_layers,
                     rnn_cell=rnn_type,
                     dropout_p=dropout,
                     bidirectional_encoder=bidirectional)

    model = Seq2Seq(enc, dec)
    os.makedirs(save_folder, exist_ok=True)

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("Loading checkpoint model %s" % model_path)
    state = torch.load(model_path)
    model.load_state_dict(state['model'])
    print('Model loaded')

    model = model.to(device)

    print(model)
    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    for test_file in test_file_list:
        test_loader = testLoader_dict[test_file]
        test_loss, test_cer, transcripts_list = evaluate(model, test_loader, criterion, device, save_output=True)

        for line in transcripts_list:
            hyp, ref = line.split('\t')
            print('STT : ' + hyp)
            print('REF : ' + ref)  # original label was '정답' ("ground truth")
            print('-' * 100)

        print("Test {} CER : {}".format(test_file, test_cer))
Example #4
def main():
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    parser = argparse.ArgumentParser(description='LAS')
    parser.add_argument('--model-name', type=str, default='LAS')
    # Dataset
    parser.add_argument('--test-file-list',
                        nargs='*',
                        help='Data list for the test dataset',
                        default=['data/Youtube/clean'])
    parser.add_argument('--labels-path',
                        default='data/kor_syllable.json',
                        help='Path to the JSON file of Korean syllable labels')
    parser.add_argument('--dataset-path',
                        default='data/Youtube/clean',
                        help='Target dataset path')

    # Hyperparameters
    parser.add_argument('--rnn-type',
                        default='lstm',
                        help='Type of the RNN. rnn|gru|lstm are supported')
    parser.add_argument('--encoder_layers',
                        type=int,
                        default=3,
                        help='Number of encoder layers (default: 3)')
    parser.add_argument('--encoder_size',
                        type=int,
                        default=512,
                        help='Encoder hidden size (default: 512)')
    parser.add_argument('--decoder_layers',
                        type=int,
                        default=2,
                        help='Number of decoder layers (default: 2)')
    parser.add_argument('--decoder_size',
                        type=int,
                        default=512,
                        help='Decoder hidden size (default: 512)')
    parser.add_argument(
        '--no-bidirectional',
        dest='bidirectional',
        action='store_false',
        default=True,
        help='Turn off bi-directional encoder RNNs')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of workers in dataset loader (default: 4)')
    parser.add_argument('--num_gpu',
                        type=int,
                        default=1,
                        help='Number of GPUs (default: 1)')
    parser.add_argument('--learning-anneal',
                        default=1.1,
                        type=float,
                        help='Factor to divide the learning rate by after each epoch')
    parser.add_argument('--max_len',
                        type=int,
                        default=80,
                        help='Maximum number of characters per sentence (default: 80)')
    parser.add_argument('--max-norm',
                        default=400,
                        type=int,
                        help='Norm cutoff to prevent explosion of gradients')

    # Audio Config
    parser.add_argument('--sample-rate',
                        default=16000,
                        type=int,
                        help='Sampling Rate')
    parser.add_argument('--window-size',
                        default=.02,
                        type=float,
                        help='Window size for spectrogram')
    parser.add_argument('--window-stride',
                        default=.01,
                        type=float,
                        help='Window stride for spectrogram')

    # System
    parser.add_argument('--model-path',
                        default='models/final.pth',
                        help='Location to save best validation model')
    parser.add_argument('--seed',
                        type=int,
                        default=123456,
                        help='random seed (default: 123456)')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    char2index, index2char = label_loader.load_label_json(args.labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride)

    print(">> Test dataset : ", args.dataset_path)

    testData_list = []

    # Build a data list from the wav files, sorted by numeric file name
    # (e.g. "0.wav", "1.wav"); the file name doubles as a placeholder
    # transcript, since no reference text is available.
    for wav_name in sorted(glob(f'{args.dataset_path}/*'),
                           key=lambda i: int(os.path.basename(i)[:-4])):
        wav_dict = {}
        wav_dict['wav'] = os.path.basename(wav_name)
        wav_dict['text'] = os.path.basename(wav_name)
        wav_dict['speaker_id'] = '0'
        testData_list.append(wav_dict)

    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      dataset_path=args.dataset_path,
                                      data_list=testData_list,
                                      char2index=char2index,
                                      sos_id=SOS_token,
                                      eos_id=EOS_token,
                                      normalize=True)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=1,
                                  num_workers=args.num_workers)

    input_size = int(math.floor((args.sample_rate * args.window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     args.encoder_size,
                     n_layers=args.encoder_layers,
                     bidirectional=args.bidirectional,
                     rnn_cell=args.rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     args.max_len,
                     args.decoder_size,
                     args.encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=args.decoder_layers,
                     rnn_cell=args.rnn_type,
                     bidirectional_encoder=args.bidirectional)

    model = Seq2Seq(enc, dec)

    print("Loading checkpoint model %s" % args.model_path)
    state = torch.load(args.model_path)
    model.load_state_dict(state['model'])

    model = model.to(device)
    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    test_loss, test_cer, transcripts_list = evaluate(model,
                                                     test_loader,
                                                     criterion,
                                                     device,
                                                     save_output=True)

    print(f"{'true':^20} | {'pred':^20}")
    for line in transcripts_list:
        print(line)

    print("Test {} CER : {}".format("test", test_cer))
Example #5
def wav_to_text():
    os.chdir(os.path.dirname(__file__))
    os.chdir('../')

    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    model_name = 'LAS'
    # Dataset
    target_file_lst = glob.glob('data/youtube/*/chunks/*.pcm')
    target_dic = []
    for chunk in target_file_lst:
        # dummy transcript; only the decoded output is used downstream
        target_dic.append({'wav': chunk, 'text': '1'})
    with open('data/youtube/target_file.json', 'w',
              encoding='utf-8') as json_file:
        json.dump(target_dic, json_file)
    target_file_list = ['data/youtube/target_file.json']
    labels_path = 'data/kor_syllable.json'
    dataset_path = ''

    # Hyperparameters
    rnn_type = 'lstm'
    encoder_layers = 3
    encoder_size = 512
    decoder_layers = 2
    decoder_size = 512
    dropout = 0.3
    bidirectional = True
    num_workers = 4
    max_len = 80

    # Audio Config
    sample_rate = 16000
    window_size = .02
    window_stride = .01

    # System
    save_folder = 'models'
    model_path = 'models/news_finetune/final.pth'
    cuda = True
    seed = 123456

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    char2index, index2char = label_loader.load_label_json(labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    device = torch.device('cuda' if cuda else 'cpu')

    audio_conf = dict(sample_rate=sample_rate,
                      window_size=window_size,
                      window_stride=window_stride)

    print(">> Target dataset : ", target_file_list)
    targetLoader_dict = {}

    for target_file in target_file_list:
        targetData_list = []
        with open(target_file, 'r', encoding='utf-8') as f:
            targetData_list = json.load(f)

        target_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                            dataset_path=dataset_path,
                                            data_list=targetData_list,
                                            char2index=char2index,
                                            sos_id=SOS_token,
                                            eos_id=EOS_token,
                                            normalize=True)
        targetLoader_dict[target_file] = AudioDataLoader(
            target_dataset, batch_size=1, num_workers=num_workers)

    input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     encoder_size,
                     n_layers=encoder_layers,
                     dropout_p=dropout,
                     bidirectional=bidirectional,
                     rnn_cell=rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     max_len,
                     decoder_size,
                     encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=decoder_layers,
                     rnn_cell=rnn_type,
                     dropout_p=dropout,
                     bidirectional_encoder=bidirectional)

    model = Seq2Seq(enc, dec)
    os.makedirs(save_folder, exist_ok=True)

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("Loading checkpoint model %s" % model_path)
    state = torch.load(model_path)
    model.load_state_dict(state['model'])
    print('Model loaded')

    model = model.to(device)

    print(model)
    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    result = []
    for target_file in target_file_list:
        target_loader = targetLoader_dict[target_file]
        # evaluate() returns (loss, cer, transcripts); keep only the transcripts
        _, _, transcripts_list = evaluate(model,
                                          target_loader,
                                          criterion,
                                          device,
                                          save_output=True)
        result.append(' '.join(
            [line.split('\t')[0] for line in transcripts_list]))

    # clean up the downloaded audio chunks, keeping only the manifest file
    lst = os.listdir('data/youtube')
    for x in lst:
        if x != 'target_file.json':
            shutil.rmtree('data/youtube/' + x)

    return result
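
Note: a minimal caller sketch; the module name below is hypothetical, and wav_to_text() returns one joined transcript string per target file:

# hypothetical usage; `stt_pipeline` stands in for whichever module defines wav_to_text()
from stt_pipeline import wav_to_text

for transcript in wav_to_text():
    print(transcript)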