예제 #1
0
    if args.decoder == "beam":
        scorer = None
        if args.lm_path is not None:
            scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
            scorer.set_lm_weight(args.lm_alpha)
            scorer.set_word_weight(args.lm_beta1)
            scorer.set_valid_word_weight(args.lm_beta2)
        else:
            scorer = Scorer()
        decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
    else:
        decoder = GreedyDecoder(labels, space_index=labels.index('<space>'), blank_index=labels.index('_'))

    test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.test_manifest, labels=labels,
                                      normalize=True)
    test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.num_workers)
    total_cer, total_wer = 0, 0
    for i, (data) in enumerate(test_loader):
        inputs, targets, input_percentages, target_sizes = data

        inputs = Variable(inputs, volatile=True)

        # unflatten targets
        split_targets = []
        offset = 0
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size

        if args.cuda:
            inputs = inputs.cuda()
예제 #2
0
def main():
    """Train or evaluate a DeepSpeech model configured from the command line.

    Options come from ``options.parse_args()``. The run directory
    ``args.log_dir`` receives ``test.log``, ``config.log`` and, when
    training, ``train.log``. The model is either restored from the
    checkpoint at ``args.resume`` (model, optimizer state, epoch counter and
    loss/CER history) or built from scratch from the label file and the
    audio options.

    With ``args.evaluate`` the model is validated once and the function
    returns. Otherwise every epoch trains, validates, records the average
    loss and CER, adjusts the learning rate, saves a checkpoint (flagging the
    best CER seen so far), optionally reshuffles the batches and optionally
    refreshes a Visdom plot of the loss/CER history.

    Side effects: assigns the module-level globals ``args``, ``test_logger``
    and (when training) ``train_logger``.

    Raises:
        FileExistsError: if ``args.log_dir`` (or, when training,
            ``args.checkpoint_dir``) already exists, because ``os.makedirs``
            is called without ``exist_ok``.
    """
    global args, train_logger, test_logger
    args = options.parse_args()
    os.makedirs(args.log_dir)
    test_logger = Logger(os.path.join(args.log_dir, 'test.log'))
    with open(os.path.join(args.log_dir, 'config.log'), 'w') as f:
        f.write(args.config_str)
    if not args.evaluate:
        os.makedirs(args.checkpoint_dir)
        train_logger = Logger(os.path.join(args.log_dir, 'train.log'))
    # Per-epoch history. Zero-initialised: torch.FloatTensor(n) would leave
    # the memory uninitialised, so epochs that never run (e.g. those before
    # args.start_epoch) would be plotted and checkpointed as garbage.
    loss_results, cer_results = torch.zeros(args.epochs), torch.zeros(
        args.epochs)

    if args.visdom:
        from visdom import Visdom
        viz = Visdom()
        opts = dict(title=args.experiment_id,
                    ylabel='',
                    xlabel='Epoch',
                    legend=['Loss', 'CER'])
        # Plot window handle. It must be bound here because the epoch loop
        # tests it against None to decide whether to create the window.
        viz_window = None
        epochs = torch.arange(0, args.epochs)

    if args.resume:
        print('Loading checkpoint model %s' % args.resume)
        checkpoint = torch.load(args.resume)
        model = DeepSpeech.load_model_checkpoint(checkpoint)
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.nGPU))).cuda()
        labels = DeepSpeech.get_labels(model)
        audio_conf = DeepSpeech.get_audio_conf(model)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        # NOTE(review): the loop below saves `epoch=epoch` (the 0-based index
        # of the epoch just finished). If DeepSpeech.serialize stores it
        # unchanged, resuming re-runs that epoch. Confirm this is intended.
        start_epoch = int(checkpoint.get('epoch', 0))
        loss_results, cer_results = checkpoint['loss_results'], checkpoint[
            'cer_results']
        if args.epochs > loss_results.numel():
            # Grow the stored history to the requested number of epochs and
            # clear every entry from the resume point onwards (this also
            # covers the uninitialised slots added by resize_).
            loss_results.resize_(args.epochs)
            cer_results.resize_(args.epochs)
            loss_results[start_epoch:].zero_()
            cer_results[start_epoch:].zero_()
        # Add previous scores to visdom graph. Skipped when there is no
        # history yet; the epoch loop then creates the window itself.
        if args.visdom and loss_results is not None and start_epoch > 0:
            x_axis = epochs[0:start_epoch]
            y_axis = torch.stack(
                (loss_results[0:start_epoch], cer_results[0:start_epoch]),
                dim=1)
            viz_window = viz.line(
                X=x_axis,
                Y=y_axis,
                opts=opts,
            )
    else:
        start_epoch = args.start_epoch
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))
        model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                           nb_layers=args.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[args.rnn_type],
                           audio_conf=audio_conf,
                           bidirectional=not args.look_ahead)
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.nGPU))).cuda()
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=True)

    # define loss function (criterion) and decoder
    best_cer = None
    criterion = CTCLoss()
    decoder = GreedyDecoder(labels)

    # define dataloader (the training loader is only needed when training)
    if not args.evaluate:
        train_dataset = SpectrogramDataset(
            audio_conf=audio_conf,
            manifest_filepath=args.train_manifest,
            labels=labels,
            normalize=True,
            augment=args.augment)
        train_sampler = BucketingSampler(train_dataset,
                                         batch_size=args.batch_size)
        train_loader = AudioDataLoader(train_dataset,
                                       num_workers=args.num_workers,
                                       batch_sampler=train_sampler)
        if not args.in_order and start_epoch != 0:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle()
    val_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                     manifest_filepath=args.val_manifest,
                                     labels=labels,
                                     normalize=True,
                                     augment=False)
    val_loader = AudioDataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers)

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    if args.evaluate:
        validate(val_loader, model, decoder, 0)
        return

    for epoch in range(start_epoch, args.epochs):
        avg_loss = train(train_loader, train_sampler, model, criterion,
                         optimizer, epoch)
        cer = validate(val_loader, model, decoder, epoch)

        loss_results[epoch] = avg_loss
        cer_results[epoch] = cer

        adjust_learning_rate(optimizer)

        # Lower CER is better; the first validated epoch is always "best".
        is_best = False
        if best_cer is None or best_cer > cer:
            print('Found better validated model')
            best_cer = cer
            is_best = True
        save_checkpoint(
            DeepSpeech.serialize(model,
                                 optimizer=optimizer,
                                 epoch=epoch,
                                 loss_results=loss_results,
                                 cer_results=cer_results), is_best, epoch)

        if not args.in_order:
            print("Shuffling batches...")
            train_sampler.shuffle()

        if args.visdom:
            x_axis = epochs[0:epoch + 1]
            y_axis = torch.stack(
                (loss_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1)
            if viz_window is None:
                viz_window = viz.line(
                    X=x_axis,
                    Y=y_axis,
                    opts=opts,
                )
            else:
                # Visdom fix: replicate the x axis once per plotted series
                # so X has the same (epochs, 2) shape as Y.
                viz.line(
                    X=x_axis.unsqueeze(0).expand(y_axis.size(1),
                                                 x_axis.size(0)).transpose(
                                                     0, 1),
                    Y=y_axis,
                    win=viz_window,
                    update='replace',
                )
예제 #3
0
def _build_las_arg_parser():
    """Build the command-line parser for the LAS training/testing entry point.

    Returns:
        argparse.ArgumentParser: parser covering dataset, hyperparameter,
        audio-config and system options.
    """
    parser = argparse.ArgumentParser(description='LAS')
    parser.add_argument('--model-name', type=str, default='LAS')
    # Dataset
    parser.add_argument('--train-file',
                        type=str,
                        help='data list about train dataset',
                        default='data/ClovaCall/train_ClovaCall.json')
    # NOTE(review): 'test_ClovCall.json' is spelled differently from the
    # train file's 'ClovaCall'. Confirm the on-disk name before changing it.
    parser.add_argument('--test-file-list',
                        nargs='*',
                        help='data list about test dataset',
                        default=['data/ClovaCall/test_ClovCall.json'])
    parser.add_argument('--labels-path',
                        default='data/kor_syllable.json',
                        help='Contains large characters over korean')
    parser.add_argument('--dataset-path',
                        default='data/ClovaCall/clean',
                        help='Target dataset path')
    # Hyperparameters
    parser.add_argument('--rnn-type',
                        default='lstm',
                        help='Type of the RNN. rnn|gru|lstm are supported')
    parser.add_argument('--encoder_layers',
                        type=int,
                        default=3,
                        help='number of layers of model (default: 3)')
    parser.add_argument('--encoder_size',
                        type=int,
                        default=512,
                        help='hidden size of model (default: 512)')
    parser.add_argument('--decoder_layers',
                        type=int,
                        default=2,
                        help='number of pyramidal layers (default: 2)')
    parser.add_argument('--decoder_size',
                        type=int,
                        default=512,
                        help='hidden size of model (default: 512)')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.3,
                        help='Dropout rate in training (default: 0.3)')
    parser.add_argument(
        '--no-bidirectional',
        dest='bidirectional',
        action='store_false',
        default=True,
        help='Turn off bi-directional RNNs, introduces lookahead convolution')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size in training (default: 32)')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of workers in dataset loader (default: 4)')
    parser.add_argument('--num_gpu',
                        type=int,
                        default=1,
                        help='Number of gpus (default: 1)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='Number of max epochs in training (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=3e-4,
                        help='Learning rate (default: 3e-4)')
    parser.add_argument('--learning-anneal',
                        default=1.1,
                        type=float,
                        help='Annealing learning rate every epoch')
    parser.add_argument('--teacher_forcing',
                        type=float,
                        default=1.0,
                        help='Teacher forcing ratio in decoder (default: 1.0)')
    parser.add_argument('--max_len',
                        type=int,
                        default=80,
                        help='Maximum characters of sentence (default: 80)')
    parser.add_argument('--max-norm',
                        default=400,
                        type=int,
                        help='Norm cutoff to prevent explosion of gradients')
    # Audio Config
    parser.add_argument('--sample-rate',
                        default=16000,
                        type=int,
                        help='Sampling Rate')
    parser.add_argument('--window-size',
                        default=.02,
                        type=float,
                        help='Window size for spectrogram')
    parser.add_argument('--window-stride',
                        default=.01,
                        type=float,
                        help='Window stride for spectrogram')
    # System
    parser.add_argument('--save-folder',
                        default='models',
                        help='Location to save epoch models')
    parser.add_argument('--model-path',
                        default='models/las_final.pth',
                        help='Location to save best validation model')
    parser.add_argument(
        '--log-path',
        default='log/',
        help='path to predict log about valid and test dataset')
    # store_true: passing --cuda turns CUDA on (the default is CPU).
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=123456,
                        help='random seed (default: 123456)')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='Train or Test')
    parser.add_argument('--load-model',
                        action='store_true',
                        default=False,
                        help='Load model')
    parser.add_argument('--finetune',
                        dest='finetune',
                        action='store_true',
                        default=False,
                        help='Finetune the model after load model')
    return parser


def main():
    """Train or test an LAS (Seq2Seq encoder/decoder) speech model.

    Parses command-line options, seeds all RNGs, loads the label map and the
    JSON train/test data lists, and builds the encoder/decoder model. What
    happens next depends on ``--mode``:

    * ``train`` (default): trains for ``--epochs`` epochs. After each epoch
      it evaluates every test file and saves the model and optimizer state to
      ``--model-path`` whenever the CER on the first test file improves. The
      learning rate is divided by ``--learning-anneal`` after every epoch.
    * any other value: evaluates every test file once and prints the
      reference/hypothesis transcripts and the CER.

    With ``--load-model`` the weights are restored from ``--model-path``; the
    optimizer state is restored as well unless ``--finetune`` is given.

    Side effects: assigns the module-level globals ``char2index``,
    ``index2char``, ``SOS_token``, ``EOS_token`` and ``PAD_token``, and
    creates ``--save-folder`` if it does not exist.
    """
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    args = _build_las_arg_parser().parse_args()

    # Seed every RNG source in use so runs are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    char2index, index2char = label_loader.load_label_json(args.labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    device = torch.device('cuda' if args.cuda else 'cpu')

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride)

    # Batch size across all GPUs; nn.DataParallel (below) splits each batch
    # between the devices.
    batch_size = args.batch_size * args.num_gpu

    print(">> Train dataset : ", args.train_file)
    with open(args.train_file, 'r', encoding='utf-8') as f:
        trainData_list = json.load(f)

    if args.num_gpu != 1:
        # Drop a trailing partial batch that has fewer samples than GPUs,
        # presumably so no DataParallel replica gets an empty slice.
        last_batch = len(trainData_list) % batch_size
        if last_batch != 0 and last_batch < args.num_gpu:
            trainData_list = trainData_list[:-last_batch]

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       dataset_path=args.dataset_path,
                                       data_list=trainData_list,
                                       char2index=char2index,
                                       sos_id=SOS_token,
                                       eos_id=EOS_token,
                                       normalize=True)

    train_sampler = BucketingSampler(train_dataset, batch_size=batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)

    print(">> Test dataset : ", args.test_file_list)
    testLoader_dict = {}
    for test_file in args.test_file_list:
        with open(test_file, 'r', encoding='utf-8') as f:
            testData_list = json.load(f)

        test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                          dataset_path=args.dataset_path,
                                          data_list=testData_list,
                                          char2index=char2index,
                                          sos_id=SOS_token,
                                          eos_id=EOS_token,
                                          normalize=True)
        testLoader_dict[test_file] = AudioDataLoader(
            test_dataset, batch_size=1, num_workers=args.num_workers)

    # floor(n / 2) + 1 with n = sample_rate * window_size, presumably the
    # one-sided FFT bin count of SpectrogramDataset's features (confirm).
    input_size = int(math.floor((args.sample_rate * args.window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     args.encoder_size,
                     n_layers=args.encoder_layers,
                     dropout_p=args.dropout,
                     bidirectional=args.bidirectional,
                     rnn_cell=args.rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     args.max_len,
                     args.decoder_size,
                     args.encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=args.decoder_layers,
                     rnn_cell=args.rnn_type,
                     dropout_p=args.dropout,
                     bidirectional_encoder=args.bidirectional)

    model = Seq2Seq(enc, dec)

    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)

    optim_state = None
    if args.load_model:  # Starting from previous model
        print("Loading checkpoint model %s" % args.model_path)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only run.
        state = torch.load(args.model_path, map_location=device)
        model.load_state_dict(state['model'])
        print('Model loaded')

        if not args.finetune:  # Resuming: restore the optimizer state too
            optim_state = state['optimizer']

    model = model.to(device)
    # The optimizer is built after model.to(device) so that it references
    # the parameters on their final device.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if optim_state is not None:
        optimizer.load_state_dict(optim_state)

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print(model)
    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    # Training runs through the DataParallel wrapper; evaluation and saving
    # use the bare model.
    train_model = nn.DataParallel(model)

    if args.mode != "train":
        for test_file in args.test_file_list:
            test_loader = testLoader_dict[test_file]
            test_loss, test_cer, transcripts_list = evaluate(model,
                                                             test_loader,
                                                             criterion,
                                                             device,
                                                             save_output=True)

            # Each transcript line is "<hyp>\t<ref>".
            for idx, line in enumerate(transcripts_list):
                hyp, ref = line.split('\t')
                print("({:3d}/{:3d}) [REF]: {}".format(idx + 1,
                                                       len(transcripts_list),
                                                       ref))
                print("({:3d}/{:3d}) [HYP]: {}".format(idx + 1,
                                                       len(transcripts_list),
                                                       hyp))
                print()

            print("Test {} CER : {}".format(test_file, test_cer))
    else:
        best_cer = 1e10
        begin_epoch = 0

        # Elapsed time below is cumulative since training started.
        start_time = datetime.datetime.now()

        for epoch in range(begin_epoch, args.epochs):
            train_loss, train_cer = train(train_model, train_loader, criterion,
                                          optimizer, device, epoch,
                                          train_sampler, args.max_norm,
                                          args.teacher_forcing)

            elapsed_time = datetime.datetime.now() - start_time

            train_log = 'Train({name}) Summary Epoch: [{0}]\tAverage Loss {loss:.3f}\tAverage CER {cer:.3f}\tTime {time:}'.format(
                epoch + 1,
                name='train',
                loss=train_loss,
                cer=train_cer,
                time=elapsed_time)
            print(train_log)

            cer_list = []
            for test_file in args.test_file_list:
                test_loader = testLoader_dict[test_file]
                test_loss, test_cer, _ = evaluate(model,
                                                  test_loader,
                                                  criterion,
                                                  device,
                                                  save_output=False)
                test_log = 'Test({name}) Summary Epoch: [{0}]\tAverage Loss {loss:.3f}\tAverage CER {cer:.3f}\t'.format(
                    epoch + 1, name=test_file, loss=test_loss, cer=test_cer)
                print(test_log)

                cer_list.append(test_cer)

            # Model selection uses the CER of the first test file only; with
            # no test files (empty --test-file-list) nothing is saved.
            if cer_list and best_cer > cer_list[0]:
                print("Found better validated model, saving to %s" %
                      args.model_path)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, args.model_path)
                best_cer = cer_list[0]

            print("Shuffling batches...")
            train_sampler.shuffle(epoch)

            for g in optimizer.param_groups:
                g['lr'] = g['lr'] / args.learning_anneal
            # Reports the learning rate of the last param group.
            print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))
예제 #4
0
def main():
    args = parser.parse_args()
    save_folder = args.save_folder

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_wer = None
    if args.visdom:
        from visdom import Visdom
        viz = Visdom()

        opts = [
            dict(title=args.visdom_id + ' Loss', ylabel='Loss',
                 xlabel='Epoch'),
            dict(title=args.visdom_id + ' WER', ylabel='WER', xlabel='Epoch'),
            dict(title=args.visdom_id + ' CER', ylabel='CER', xlabel='Epoch')
        ]

        viz_windows = [None, None, None]
        epochs = torch.arange(1, args.epochs + 1)
    if args.tensorboard:
        from logger import TensorBoardLogger
        try:
            os.makedirs(args.log_dir)
        except OSError as e:
            if e.errno == errno.EEXIST:
                print('Directory already exists.')
                for file in os.listdir(args.log_dir):
                    file_path = os.path.join(args.log_dir, file)
                    try:
                        if os.path.isfile(file_path):
                            os.unlink(file_path)
                    except Exception as e:
                        raise
            else:
                raise
        logger = TensorBoardLogger(args.log_dir)

    try:
        os.makedirs(save_folder)
    except OSError as e:
        if e.errno == errno.EEXIST:
            print('Directory already exists.')
        else:
            raise
    criterion = CTCLoss()

    with open(args.labels_path) as label_file:
        # labels = str(''.join(json.load(label_file)))
        labels = json.load(label_file)
    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      manifest_filepath=args.val_manifest,
                                      labels=labels,
                                      normalize=True,
                                      augment=False)
    train_loader = AudioDataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                       nb_layers=args.hidden_layers,
                       labels=labels,
                       rnn_type=supported_rnns[rnn_type],
                       audio_conf=audio_conf,
                       bidirectional=True)
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True)
    # decoder = GreedyDecoder(labels)
    decoder = GreedyDecoder(labels,
                            space_index=labels.index('<space>'),
                            blank_index=labels.index('_'))

    if args.continue_from:
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from)
        model.load_state_dict(package['state_dict'])
        optimizer.load_state_dict(package['optim_dict'])
        start_epoch = int(package.get(
            'epoch', 1)) - 1  # Python index start at 0 for training
        start_iter = package.get('iteration', None)
        if start_iter is None:
            start_epoch += 1  # Assume that we saved a model after an epoch finished, so start at the next epoch.
            start_iter = 0
        else:
            start_iter += 1
        avg_loss = int(package.get('avg_loss', 0))
        loss_results, cer_results, wer_results = package[
            'loss_results'], package['cer_results'], package['wer_results']
        if args.visdom and \
                        package['loss_results'] is not None and start_epoch > 0:  # Add previous scores to visdom graph
            x_axis = epochs[0:start_epoch]
            y_axis = [
                loss_results[0:start_epoch], wer_results[0:start_epoch],
                cer_results[0:start_epoch]
            ]
            for x in range(len(viz_windows)):
                viz_windows[x] = viz.line(
                    X=x_axis,
                    Y=y_axis[x],
                    opts=opts[x],
                )
        if args.tensorboard and \
                        package['loss_results'] is not None and start_epoch > 0:  # Previous scores to tensorboard logs
            for i in range(start_epoch):
                info = {
                    'Avg Train Loss': loss_results[i],
                    'Avg WER': wer_results[i],
                    'Avg CER': cer_results[i]
                }
                for tag, val in info.items():
                    logger.scalar_summary(tag, val, i + 1)
        if not args.no_bucketing:
            print("Using bucketing sampler for the following epochs")
            train_dataset = SpectrogramDatasetWithLength(
                audio_conf=audio_conf,
                manifest_filepath=args.train_manifest,
                labels=labels,
                normalize=True,
                augment=args.augment)
            sampler = BucketingSampler(train_dataset)
            train_loader.sampler = sampler
    else:
        avg_loss = 0
        start_epoch = 0
        start_iter = 0
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        end = time.time()
        for i, (data) in enumerate(train_loader, start=start_iter):
            if i == len(train_loader):
                break
            inputs, targets, input_percentages, target_sizes = data
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = Variable(inputs, requires_grad=False)
            target_sizes = Variable(target_sizes, requires_grad=False)
            targets = Variable(targets, requires_grad=False)

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH

            seq_length = out.size(0)
            sizes = Variable(input_percentages.mul_(int(seq_length)).int(),
                             requires_grad=False)

            loss = criterion(out, targets, sizes, target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            loss_sum = loss.data.sum()
            inf = float("inf")
            if loss_sum == inf or loss_sum == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0
            else:
                loss_value = loss.data[0]

            avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # compute gradient
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm(model.parameters(), args.max_norm)
            # SGD step
            optimizer.step()

            if args.cuda:
                torch.cuda.synchronize()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.silent:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          (epoch + 1), (i + 1),
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses))
            if args.checkpoint_per_batch > 0 and i > 0 and (
                    i + 1) % args.checkpoint_per_batch == 0:
                file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth.tar' % (
                    save_folder, epoch + 1, i + 1)
                print("Saving checkpoint model to %s" % file_path)
                torch.save(
                    DeepSpeech.serialize(model,
                                         optimizer=optimizer,
                                         epoch=epoch,
                                         iteration=i,
                                         loss_results=loss_results,
                                         wer_results=wer_results,
                                         cer_results=cer_results,
                                         avg_loss=avg_loss), file_path)
            del loss
            del out
        avg_loss /= len(train_loader)

        print('Training Summary Epoch: [{0}]\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, loss=avg_loss))

        start_iter = 0  # Reset start iteration for next epoch
        total_cer, total_wer = 0, 0
        model.eval()
        for i, (data) in enumerate(test_loader):  # test
            inputs, targets, input_percentages, target_sizes = data

            inputs = Variable(inputs, volatile=True)

            # unflatten targets
            split_targets = []
            offset = 0
            for size in target_sizes:
                split_targets.append(targets[offset:offset + size])
                offset += size

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH
            seq_length = out.size(0)
            sizes = input_percentages.mul_(int(seq_length)).int()

            decoded_output = decoder.decode(out.data, sizes)
            target_strings = decoder.process_strings(
                decoder.convert_to_strings(split_targets))
            wer, cer = 0, 0
            for x in range(len(target_strings)):
                # wer += decoder.wer(decoded_output[x], target_strings[x]) / float(len(target_strings[x].split()))
                # cer += decoder.cer(decoded_output[x], target_strings[x]) / float(len(target_strings[x]))
                wer += decoder.wer(
                    decoded_output[x], target_strings[x]) / float(
                        len(target_strings[x].replace(' ',
                                                      '').split('<space>')))
                cer += decoder.cer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x].split(' ')))
            total_cer += cer
            total_wer += wer

            if args.cuda:
                torch.cuda.synchronize()
            del out
        wer = total_wer / len(test_loader.dataset)
        cer = total_cer / len(test_loader.dataset)
        wer *= 100
        cer *= 100
        loss_results[epoch] = avg_loss
        wer_results[epoch] = wer
        cer_results[epoch] = cer
        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

        if args.visdom:
            # epoch += 1
            x_axis = epochs[0:epoch + 1]
            y_axis = [
                loss_results[0:epoch + 1], wer_results[0:epoch + 1],
                cer_results[0:epoch + 1]
            ]
            for x in range(len(viz_windows)):
                if viz_windows[x] is None:
                    viz_windows[x] = viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        opts=opts[x],
                    )
                else:
                    viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        win=viz_windows[x],
                        update='replace',
                    )
        if args.tensorboard:
            info = {'Avg Train Loss': avg_loss, 'Avg WER': wer, 'Avg CER': cer}
            for tag, val in info.items():
                logger.scalar_summary(tag, val, epoch + 1)
            if args.log_params:
                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    logger.histo_summary(tag, to_np(value), epoch + 1)
                    logger.histo_summary(tag + '/grad', to_np(value.grad),
                                         epoch + 1)
        if args.checkpoint:
            file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
            torch.save(
                DeepSpeech.serialize(model,
                                     optimizer=optimizer,
                                     epoch=epoch,
                                     loss_results=loss_results,
                                     wer_results=wer_results,
                                     cer_results=cer_results), file_path)
        # anneal lr
        optim_state = optimizer.state_dict()
        optim_state['param_groups'][0][
            'lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
        optimizer.load_state_dict(optim_state)
        print('Learning rate annealed to: {lr:.6f}'.format(
            lr=optim_state['param_groups'][0]['lr']))

        if best_wer is None or best_wer > wer:
            print("Found better validated model, saving to %s" %
                  args.model_path)
            torch.save(
                DeepSpeech.serialize(model,
                                     optimizer=optimizer,
                                     epoch=epoch,
                                     loss_results=loss_results,
                                     wer_results=wer_results,
                                     cer_results=cer_results), args.model_path)
            best_wer = wer

        avg_loss = 0
        if not args.no_bucketing and epoch == 0:
            print("Switching to bucketing sampler for following epochs")
            train_dataset = SpectrogramDatasetWithLength(
                audio_conf=audio_conf,
                manifest_filepath=args.train_manifest,
                labels=labels,
                normalize=True,
                augment=args.augment)
            sampler = BucketingSampler(train_dataset)
            train_loader.sampler = sampler
예제 #5
0
def main():
    """Transcribe every audio file in a directory with a trained LAS model.

    Parses command-line options, loads the label vocabulary and a Seq2Seq
    checkpoint from ``--model-path``, builds a test set from all files in
    ``--dataset-path`` (ordered by the integer that forms each file's base
    name), runs ``evaluate`` on it and prints the true/predicted transcript
    lines followed by the CER.

    Side effects: assigns the module-level ``char2index``, ``index2char``,
    ``SOS_token``, ``EOS_token`` and ``PAD_token`` globals and seeds the
    torch, numpy and ``random`` RNGs.

    NOTE(review): each sample's reference ``text`` is set to its own file
    name, so the printed CER compares predictions against file names —
    confirm this is intended.
    """
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    parser = argparse.ArgumentParser(description='LAS')
    # NOTE: --model-name, --test-file-list, --num_gpu, --learning-anneal and
    # --max-norm are accepted but not read anywhere in this function.
    parser.add_argument('--model-name', type=str, default='LAS')
    # Dataset
    parser.add_argument('--test-file-list',
                        nargs='*',
                        help='data list about test dataset',
                        default=['data/Youtube/clean'])
    parser.add_argument('--labels-path',
                        default='data/kor_syllable.json',
                        help='Contains large characters over korean')
    parser.add_argument('--dataset-path',
                        default='data/Youtube/clean',
                        help='Target dataset path')

    # Hyperparameters
    parser.add_argument('--rnn-type',
                        default='lstm',
                        help='Type of the RNN. rnn|gru|lstm are supported')
    parser.add_argument('--encoder_layers',
                        type=int,
                        default=3,
                        help='number of layers of model (default: 3)')
    parser.add_argument('--encoder_size',
                        type=int,
                        default=512,
                        help='hidden size of model (default: 512)')
    parser.add_argument('--decoder_layers',
                        type=int,
                        default=2,
                        help='number of pyramidal layers (default: 2)')
    parser.add_argument('--decoder_size',
                        type=int,
                        default=512,
                        help='hidden size of model (default: 512)')
    parser.add_argument(
        '--no-bidirectional',
        dest='bidirectional',
        action='store_false',
        default=True,
        help='Turn off bi-directional RNNs, introduces lookahead convolution')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of workers in dataset loader (default: 4)')
    parser.add_argument('--num_gpu',
                        type=int,
                        default=1,
                        help='Number of gpus (default: 1)')
    parser.add_argument('--learning-anneal',
                        default=1.1,
                        type=float,
                        help='Annealing learning rate every epoch')
    parser.add_argument('--max_len',
                        type=int,
                        default=80,
                        help='Maximum characters of sentence (default: 80)')
    parser.add_argument('--max-norm',
                        default=400,
                        type=int,
                        help='Norm cutoff to prevent explosion of gradients')

    # Audio Config
    parser.add_argument('--sample-rate',
                        default=16000,
                        type=int,
                        help='Sampling Rate')
    parser.add_argument('--window-size',
                        default=.02,
                        type=float,
                        help='Window size for spectrogram')
    parser.add_argument('--window-stride',
                        default=.01,
                        type=float,
                        help='Window stride for spectrogram')

    # System
    parser.add_argument('--model-path',
                        default='models/final.pth',
                        help='Location to save best validation model')
    parser.add_argument('--seed',
                        type=int,
                        default=123456,
                        help='random seed (default: 123456)')
    args = parser.parse_args()

    # Seed every RNG in use so repeated runs are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    char2index, index2char = label_loader.load_label_json(args.labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride)

    print(">> Test dataset : ", args.dataset_path)

    def _numeric_stem(path):
        # File names are expected to be an integer plus an extension
        # (e.g. ``12.wav``). ``splitext`` works for any extension length,
        # unlike slicing off a fixed 4 characters.
        return int(os.path.splitext(os.path.basename(path))[0])

    testData_list = [
        {
            'wav': os.path.basename(wav_name),
            'text': os.path.basename(wav_name),
            'speaker_id': '0',
        }
        for wav_name in sorted(glob(f'{args.dataset_path}/*'),
                               key=_numeric_stem)
    ]

    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      dataset_path=args.dataset_path,
                                      data_list=testData_list,
                                      char2index=char2index,
                                      sos_id=SOS_token,
                                      eos_id=EOS_token,
                                      normalize=True)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=1,
                                  num_workers=args.num_workers)

    # Number of frequency bins of the spectrogram: window length in samples
    # halved, plus one.
    input_size = int(math.floor((args.sample_rate * args.window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     args.encoder_size,
                     n_layers=args.encoder_layers,
                     bidirectional=args.bidirectional,
                     rnn_cell=args.rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     args.max_len,
                     args.decoder_size,
                     args.encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=args.decoder_layers,
                     rnn_cell=args.rnn_type,
                     bidirectional_encoder=args.bidirectional)

    model = Seq2Seq(enc, dec)

    print("Loading checkpoint model %s" % args.model_path)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host
    # (the device above falls back to CPU when CUDA is unavailable).
    state = torch.load(args.model_path, map_location=device)
    model.load_state_dict(state['model'])

    model = model.to(device)
    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    test_loss, test_cer, transcripts_list = evaluate(model,
                                                     test_loader,
                                                     criterion,
                                                     device,
                                                     save_output=True)

    print(f"{'true':^20} | {'pred':^20}")
    for line in transcripts_list:
        print(line)

    print("Test {} CER : {}".format("test", test_cer))
예제 #6
0
def main():
    """Evaluate a trained LAS checkpoint on hard-coded JSON test lists.

    Changes the working directory to the parent of this script's directory,
    loads the label vocabulary and the checkpoint at ``model_path``, then for
    each JSON file in ``test_file_list`` builds a data loader, runs
    ``evaluate`` and prints every STT / reference pair followed by the CER.

    Side effects: changes the process working directory, creates the
    ``models`` folder if missing, assigns the module-level ``char2index``,
    ``index2char``, ``SOS_token``, ``EOS_token`` and ``PAD_token`` globals,
    and seeds the torch, numpy and ``random`` RNGs.
    """
    # ``abspath`` matters: when ``__file__`` is a bare file name, ``dirname``
    # returns '' and ``os.chdir('')`` raises FileNotFoundError.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    os.chdir('../')

    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    # Dataset
    test_file_list = ['data/Youtube_test/youtube_test.json']
    labels_path = 'data/kor_syllable.json'
    dataset_path = 'data/Youtube_test'

    # Hyperparameters (must match the checkpoint's architecture)
    rnn_type = 'lstm'
    encoder_layers = 3
    encoder_size = 512
    decoder_layers = 2
    decoder_size = 512
    dropout = 0.3
    bidirectional = True
    num_workers = 4
    max_len = 80

    # Audio Config
    sample_rate = 16000
    window_size = .02
    window_stride = .01

    # System
    save_folder = 'models'
    model_path = 'models/AIHub_train/LSTM_512x3_512x2_AIHub_train/final.pth'
    cuda = True
    seed = 123456

    # Seed every RNG in use so repeated runs are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    char2index, index2char = label_loader.load_label_json(labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    # Fall back to CPU when CUDA is requested but not available instead of
    # failing later on the first ``.to(device)``.
    device = torch.device(
        'cuda' if cuda and torch.cuda.is_available() else 'cpu')

    audio_conf = dict(sample_rate=sample_rate,
                      window_size=window_size,
                      window_stride=window_stride)

    print(">> Test dataset : ", test_file_list)
    testLoader_dict = {}

    for test_file in test_file_list:
        with open(test_file, 'r', encoding='utf-8') as f:
            testData_list = json.load(f)

        test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                          dataset_path=dataset_path,
                                          data_list=testData_list,
                                          char2index=char2index,
                                          sos_id=SOS_token,
                                          eos_id=EOS_token,
                                          normalize=True)
        testLoader_dict[test_file] = AudioDataLoader(
            test_dataset, batch_size=1, num_workers=num_workers)

    # Number of frequency bins of the spectrogram: window length in samples
    # halved, plus one.
    input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
    enc = EncoderRNN(input_size, encoder_size, n_layers=encoder_layers,
                     dropout_p=dropout, bidirectional=bidirectional,
                     rnn_cell=rnn_type, variable_lengths=False)

    dec = DecoderRNN(len(char2index), max_len, decoder_size, encoder_size,
                     SOS_token, EOS_token,
                     n_layers=decoder_layers, rnn_cell=rnn_type,
                     dropout_p=dropout, bidirectional_encoder=bidirectional)

    model = Seq2Seq(enc, dec)
    os.makedirs(save_folder, exist_ok=True)

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("Loading checkpoint model %s" % model_path)
    # map_location lets a GPU-saved checkpoint load when running on CPU.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model'])
    print('Model loaded')

    model = model.to(device)

    print(model)
    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    for test_file in test_file_list:
        test_loader = testLoader_dict[test_file]
        test_loss, test_cer, transcripts_list = evaluate(
            model, test_loader, criterion, device, save_output=True)

        # Each transcript line is tab-separated: prediction, then reference
        # ('정답' = "answer", i.e. the reference transcript).
        for line in transcripts_list:
            print('STT : ' + line.split('\t')[0])
            print('정답 : ' + line.split('\t')[1])
            print('-'*100)

        print("Test {} CER : {}".format(test_file, test_cer))
예제 #7
0
def wav_to_text():
    """Transcribe all ``data/youtube/*/chunks/*.pcm`` chunks with LAS.

    Changes the working directory to the parent of this script's directory,
    writes a manifest ``data/youtube/target_file.json`` listing every chunk
    (with a placeholder ``'1'`` as its text), decodes it with the checkpoint
    at ``model_path``, and then deletes every sub-directory of
    ``data/youtube``.

    Returns:
        list[str]: one entry per manifest file, each being the predicted
        transcripts of that file joined with single spaces.

    Side effects: changes the process working directory, writes the
    manifest, creates the ``models`` folder if missing, deletes
    directories under ``data/youtube``, assigns the module-level
    ``char2index``, ``index2char``, ``SOS_token``, ``EOS_token`` and
    ``PAD_token`` globals, and seeds the torch, numpy and ``random`` RNGs.
    """
    # ``abspath`` matters: when ``__file__`` is a bare file name, ``dirname``
    # returns '' and ``os.chdir('')`` raises FileNotFoundError.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    os.chdir('../')

    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    youtube_dir = 'data/youtube'
    manifest_name = 'target_file.json'

    # Dataset: build a manifest of every audio chunk. No reference text
    # exists, so a placeholder is stored.
    target_file_lst = glob.glob('data/youtube/*/chunks/*.pcm')
    target_entries = [{'wav': chunk, 'text': '1'} for chunk in target_file_lst]
    with open('data/youtube/target_file.json', 'w',
              encoding='utf-8') as json_file:
        json.dump(target_entries, json_file)
    target_file_list = ['data/youtube/target_file.json']
    labels_path = 'data/kor_syllable.json'
    # Chunk paths in the manifest are already relative to the working dir.
    dataset_path = ''

    # Hyperparameters (must match the checkpoint's architecture)
    rnn_type = 'lstm'
    encoder_layers = 3
    encoder_size = 512
    decoder_layers = 2
    decoder_size = 512
    dropout = 0.3
    bidirectional = True
    num_workers = 4
    max_len = 80

    # Audio Config
    sample_rate = 16000
    window_size = .02
    window_stride = .01

    # System
    save_folder = 'models'
    model_path = 'models/news_finetune/final.pth'
    cuda = True
    seed = 123456

    # Seed every RNG in use so repeated runs are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    char2index, index2char = label_loader.load_label_json(labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    # Fall back to CPU when CUDA is requested but not available.
    device = torch.device(
        'cuda' if cuda and torch.cuda.is_available() else 'cpu')

    audio_conf = dict(sample_rate=sample_rate,
                      window_size=window_size,
                      window_stride=window_stride)

    print(">> Target dataset : ", target_file_list)
    targetLoader_dict = {}

    for target_file in target_file_list:
        with open(target_file, 'r', encoding='utf-8') as f:
            targetData_list = json.load(f)

        target_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                            dataset_path=dataset_path,
                                            data_list=targetData_list,
                                            char2index=char2index,
                                            sos_id=SOS_token,
                                            eos_id=EOS_token,
                                            normalize=True)
        targetLoader_dict[target_file] = AudioDataLoader(
            target_dataset, batch_size=1, num_workers=num_workers)

    # Number of frequency bins of the spectrogram: window length in samples
    # halved, plus one.
    input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     encoder_size,
                     n_layers=encoder_layers,
                     dropout_p=dropout,
                     bidirectional=bidirectional,
                     rnn_cell=rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     max_len,
                     decoder_size,
                     encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=decoder_layers,
                     rnn_cell=rnn_type,
                     dropout_p=dropout,
                     bidirectional_encoder=bidirectional)

    model = Seq2Seq(enc, dec)
    os.makedirs(save_folder, exist_ok=True)

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("Loading checkpoint model %s" % model_path)
    # map_location lets a GPU-saved checkpoint load when running on CPU.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model'])
    print('Model loaded')

    model = model.to(device)

    print(model)
    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    result = []
    for target_file in target_file_list:
        target_loader = targetLoader_dict[target_file]
        output = evaluate(model,
                          target_loader,
                          criterion,
                          device,
                          save_output=True)
        # Elsewhere ``evaluate(..., save_output=True)`` is unpacked as
        # (loss, cer, transcripts); accept that form as well as a bare
        # transcript list so ``.split`` is never called on the loss value.
        transcripts_list = output[-1] if isinstance(output, tuple) else output
        # Each transcript line is tab-separated; field 0 is the prediction.
        result.append(' '.join(
            [line.split('\t')[0] for line in transcripts_list]))

    # Remove the per-video working directories, keeping the manifest.
    # ``rmtree`` only accepts real directories, so plain files and symlinks
    # are left in place instead of aborting the cleanup (and the return).
    for entry in os.listdir(youtube_dir):
        if entry == manifest_name:
            continue
        entry_path = os.path.join(youtube_dir, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)

    return result