Example #1
0
def main():
    """Command-line entry point: train or evaluate the LAS seq2seq model.

    Steps, in order:

    1. Parse the command-line options.
    2. Seed ``torch``, ``numpy`` and ``random`` with ``--seed``.
    3. Load the label map and publish ``char2index``, ``index2char``,
       ``SOS_token``, ``EOS_token`` and ``PAD_token`` as module globals.
    4. Build the training loader and one test loader per ``--test-file-list``
       entry.
    5. Build the encoder/decoder model, optionally restoring weights (and,
       unless ``--finetune`` is given, optimizer state) from ``--model-path``.
    6. If ``--mode`` is anything other than ``train``: evaluate every test
       file and print REF/HYP pairs plus the CER. Otherwise run the training
       loop, evaluating after each epoch and saving a checkpoint to
       ``--model-path`` whenever the CER on the *first* test file improves.

    Side effects: sets the globals listed above, creates ``--save-folder``
    (and, when training, the parent directory of ``--model-path``), writes
    checkpoints, and prints progress to stdout.
    """
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    parser = argparse.ArgumentParser(description='LAS')
    parser.add_argument('--model-name', type=str, default='LAS')
    # Dataset
    parser.add_argument('--train-file',
                        type=str,
                        help='data list about train dataset',
                        default='data/ClovaCall/train_ClovaCall.json')
    # NOTE(review): 'test_ClovCall.json' looks like a typo for
    # 'test_ClovaCall.json' (compare the train default above) - confirm
    # against the dataset layout before changing this runtime path.
    parser.add_argument('--test-file-list',
                        nargs='*',
                        help='data list about test dataset',
                        default=['data/ClovaCall/test_ClovCall.json'])
    parser.add_argument('--labels-path',
                        default='data/kor_syllable.json',
                        help='Contains large characters over korean')
    parser.add_argument('--dataset-path',
                        default='data/ClovaCall/clean',
                        help='Target dataset path')
    # Hyperparameters
    parser.add_argument('--rnn-type',
                        default='lstm',
                        help='Type of the RNN. rnn|gru|lstm are supported')
    parser.add_argument('--encoder_layers',
                        type=int,
                        default=3,
                        help='number of layers of model (default: 3)')
    parser.add_argument('--encoder_size',
                        type=int,
                        default=512,
                        help='hidden size of model (default: 512)')
    parser.add_argument('--decoder_layers',
                        type=int,
                        default=2,
                        help='number of pyramidal layers (default: 2)')
    parser.add_argument('--decoder_size',
                        type=int,
                        default=512,
                        help='hidden size of model (default: 512)')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.3,
                        help='Dropout rate in training (default: 0.3)')
    parser.add_argument(
        '--no-bidirectional',
        dest='bidirectional',
        action='store_false',
        default=True,
        help='Turn off bi-directional RNNs, introduces lookahead convolution')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size in training (default: 32)')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of workers in dataset loader (default: 4)')
    parser.add_argument('--num_gpu',
                        type=int,
                        default=1,
                        help='Number of gpus (default: 1)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='Number of max epochs in training (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=3e-4,
                        help='Learning rate (default: 3e-4)')
    parser.add_argument('--learning-anneal',
                        default=1.1,
                        type=float,
                        help='Annealing learning rate every epoch')
    parser.add_argument('--teacher_forcing',
                        type=float,
                        default=1.0,
                        help='Teacher forcing ratio in decoder (default: 1.0)')
    parser.add_argument('--max_len',
                        type=int,
                        default=80,
                        help='Maximum characters of sentence (default: 80)')
    parser.add_argument('--max-norm',
                        default=400,
                        type=int,
                        help='Norm cutoff to prevent explosion of gradients')
    # Audio Config
    parser.add_argument('--sample-rate',
                        default=16000,
                        type=int,
                        help='Sampling Rate')
    parser.add_argument('--window-size',
                        default=.02,
                        type=float,
                        help='Window size for spectrogram')
    parser.add_argument('--window-stride',
                        default=.01,
                        type=float,
                        help='Window stride for spectrogram')
    # System
    parser.add_argument('--save-folder',
                        default='models',
                        help='Location to save epoch models')
    parser.add_argument('--model-path',
                        default='models/las_final.pth',
                        help='Location to save best validation model')
    parser.add_argument(
        '--log-path',
        default='log/',
        help='path to predict log about valid and test dataset')
    # store_true: passing the flag turns CUDA *on*; the help text previously
    # claimed the opposite.
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=123456,
                        help='random seed (default: 123456)')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='Train or Test')
    parser.add_argument('--load-model',
                        action='store_true',
                        default=False,
                        help='Load model')
    parser.add_argument('--finetune',
                        dest='finetune',
                        action='store_true',
                        default=False,
                        help='Finetune the model after load model')
    args = parser.parse_args()

    training = args.mode == "train"
    # nargs='*' permits an empty list. Training selects the best checkpoint
    # from the first test file's CER, so fail fast instead of hitting an
    # IndexError after a whole epoch has already been trained.
    if training and not args.test_file_list:
        parser.error('--test-file-list needs at least one file in train mode')

    # Seed every RNG in use so runs are repeatable for a given --seed.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    char2index, index2char = label_loader.load_label_json(args.labels_path)
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']

    device = torch.device('cuda' if args.cuda else 'cpu')

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride)

    # Effective batch size: the per-GPU batch size times the GPU count.
    batch_size = args.batch_size * args.num_gpu

    print(">> Train dataset : ", args.train_file)
    with open(args.train_file, 'r', encoding='utf-8') as f:
        train_data_list = json.load(f)

    if args.num_gpu != 1:
        # Drop a trailing partial batch that holds fewer samples than there
        # are GPUs.
        last_batch = len(train_data_list) % batch_size
        if last_batch != 0 and last_batch < args.num_gpu:
            train_data_list = train_data_list[:-last_batch]

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       dataset_path=args.dataset_path,
                                       data_list=train_data_list,
                                       char2index=char2index,
                                       sos_id=SOS_token,
                                       eos_id=EOS_token,
                                       normalize=True)

    train_sampler = BucketingSampler(train_dataset, batch_size=batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)

    print(">> Test dataset : ", args.test_file_list)
    test_loaders = {}  # test file path -> loader with batch_size=1
    for test_file in args.test_file_list:
        with open(test_file, 'r', encoding='utf-8') as f:
            test_data_list = json.load(f)

        test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                          dataset_path=args.dataset_path,
                                          data_list=test_data_list,
                                          char2index=char2index,
                                          sos_id=SOS_token,
                                          eos_id=EOS_token,
                                          normalize=True)
        test_loaders[test_file] = AudioDataLoader(
            test_dataset, batch_size=1, num_workers=args.num_workers)

    # Encoder input width is derived from the samples per analysis window:
    # floor(sample_rate * window_size / 2) + 1.
    input_size = int(math.floor((args.sample_rate * args.window_size) / 2) + 1)
    enc = EncoderRNN(input_size,
                     args.encoder_size,
                     n_layers=args.encoder_layers,
                     dropout_p=args.dropout,
                     bidirectional=args.bidirectional,
                     rnn_cell=args.rnn_type,
                     variable_lengths=False)

    dec = DecoderRNN(len(char2index),
                     args.max_len,
                     args.decoder_size,
                     args.encoder_size,
                     SOS_token,
                     EOS_token,
                     n_layers=args.decoder_layers,
                     rnn_cell=args.rnn_type,
                     dropout_p=args.dropout,
                     bidirectional_encoder=args.bidirectional)

    model = Seq2Seq(enc, dec)

    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)

    optim_state = None
    if args.load_model:  # Starting from previous model
        print("Loading checkpoint model %s" % args.model_path)
        # Load onto the CPU so a checkpoint written on a GPU machine can be
        # restored without CUDA; the model is moved to `device` below and
        # Optimizer.load_state_dict moves its state to the parameters' device.
        state = torch.load(args.model_path, map_location='cpu')
        model.load_state_dict(state['model'])
        print('Model loaded')

        if not args.finetune:  # Resume: also restore the optimizer state
            optim_state = state['optimizer']

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if optim_state is not None:
        optimizer.load_state_dict(optim_state)

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    print(model)
    print("Number of parameters: %d" % Seq2Seq.get_param_size(model))

    # The DataParallel wrapper is used for training only; evaluation and
    # checkpointing go through the bare `model`.
    train_model = nn.DataParallel(model)

    if not training:
        for test_file in args.test_file_list:
            test_loader = test_loaders[test_file]
            test_loss, test_cer, transcripts_list = evaluate(model,
                                                             test_loader,
                                                             criterion,
                                                             device,
                                                             save_output=True)

            total = len(transcripts_list)
            for idx, line in enumerate(transcripts_list, start=1):
                # Each transcript line is "<hypothesis>\t<reference>".
                hyp, ref = line.split('\t')
                print("({:3d}/{:3d}) [REF]: {}".format(idx, total, ref))
                print("({:3d}/{:3d}) [HYP]: {}".format(idx, total, hyp))
                print()

            print("Test {} CER : {}".format(test_file, test_cer))
    else:
        # --model-path may live outside --save-folder; make sure its parent
        # directory exists before the first torch.save.
        model_dir = os.path.dirname(args.model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        best_cer = 1e10
        begin_epoch = 0

        start_time = datetime.datetime.now()

        for epoch in range(begin_epoch, args.epochs):
            train_loss, train_cer = train(train_model, train_loader, criterion,
                                          optimizer, device, epoch,
                                          train_sampler, args.max_norm,
                                          args.teacher_forcing)

            # Wall-clock time since training started (cumulative, not
            # per-epoch).
            elapsed_time = datetime.datetime.now() - start_time

            train_log = 'Train({name}) Summary Epoch: [{0}]\tAverage Loss {loss:.3f}\tAverage CER {cer:.3f}\tTime {time:}'.format(
                epoch + 1,
                name='train',
                loss=train_loss,
                cer=train_cer,
                time=elapsed_time)
            print(train_log)

            cer_list = []
            for test_file in args.test_file_list:
                test_loader = test_loaders[test_file]
                test_loss, test_cer, _ = evaluate(model,
                                                  test_loader,
                                                  criterion,
                                                  device,
                                                  save_output=False)
                test_log = 'Test({name}) Summary Epoch: [{0}]\tAverage Loss {loss:.3f}\tAverage CER {cer:.3f}\t'.format(
                    epoch + 1, name=test_file, loss=test_loss, cer=test_cer)
                print(test_log)

                cer_list.append(test_cer)

            # Model selection uses the first test file only.
            if best_cer > cer_list[0]:
                print("Found better validated model, saving to %s" %
                      args.model_path)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, args.model_path)
                best_cer = cer_list[0]

            print("Shuffling batches...")
            train_sampler.shuffle(epoch)

            # Divide every group's learning rate by the anneal factor.
            for group in optimizer.param_groups:
                group['lr'] = group['lr'] / args.learning_anneal
            # Report the last group's rate explicitly instead of relying on
            # the loop variable leaking out of the for statement.
            annealed_lr = optimizer.param_groups[-1]['lr']
            print('Learning rate annealed to: {lr:.6f}'.format(lr=annealed_lr))
Example #2
0
def main():
    """Train or evaluate a DeepSpeech model according to the parsed options.

    Steps, in order:

    1. Parse options into the global ``args`` and create the log directory
       (plus the checkpoint directory and train logger when not evaluating).
    2. Either resume the model, optimizer state, start epoch and result
       histories from ``args.resume``, or build a fresh model and SGD
       optimizer from the options.
    3. Build the data loaders (the training loader only when not evaluating).
    4. With ``args.evaluate`` set, run one validation pass and return.
       Otherwise train for epochs ``start_epoch .. args.epochs - 1``,
       validating, adjusting the learning rate and checkpointing after each
       epoch, and optionally plotting loss/CER to Visdom.

    Side effects: sets the globals ``args``, ``train_logger`` and
    ``test_logger``, creates directories, and writes log and config files.
    """
    global args, train_logger, test_logger
    args = options.parse_args()
    # No exist_ok: this raises if the log directory already exists.
    # NOTE(review): presumably intentional, to avoid overwriting a previous
    # experiment's logs - confirm.
    os.makedirs(args.log_dir)
    test_logger = Logger(os.path.join(args.log_dir, 'test.log'))
    with open(os.path.join(args.log_dir, 'config.log'), 'w') as f:
        f.write(args.config_str)
    if not args.evaluate:
        os.makedirs(args.checkpoint_dir)
        train_logger = Logger(os.path.join(args.log_dir, 'train.log'))
    # Per-epoch histories; replaced by the checkpoint's tensors on resume.
    loss_results, cer_results = torch.FloatTensor(
        args.epochs), torch.FloatTensor(args.epochs)

    if args.visdom:
        from visdom import Visdom
        viz = Visdom()
        opts = dict(title=args.experiment_id,
                    ylabel='',
                    xlabel='Epoch',
                    legend=['Loss', 'CER'])
        # Must be the same name that the training loop tests below; it was
        # previously initialised as `viz_windows`, which left `viz_window`
        # unbound on a fresh (non-resumed) run with --visdom.
        viz_window = None
        epochs = torch.arange(0, args.epochs)

    if args.resume:
        print('Loading checkpoint model %s' % args.resume)
        checkpoint = torch.load(args.resume)
        model = DeepSpeech.load_model_checkpoint(checkpoint)
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.nGPU))).cuda()
        labels = DeepSpeech.get_labels(model)
        audio_conf = DeepSpeech.get_audio_conf(model)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = int(checkpoint.get('epoch',
                                         0))  # Index start at 0 for training
        loss_results, cer_results = checkpoint['loss_results'], checkpoint[
            'cer_results']
        if args.epochs > loss_results.numel():
            # Grow the histories to the new epoch count and clear the slots
            # that have not been trained yet.
            loss_results.resize_(args.epochs)
            cer_results.resize_(args.epochs)
            loss_results[start_epoch:].zero_()
            cer_results[start_epoch:].zero_()
        # Add previous scores to visdom graph
        if args.visdom and loss_results is not None:
            x_axis = epochs[0:start_epoch]
            y_axis = torch.stack(
                (loss_results[0:start_epoch], cer_results[0:start_epoch]),
                dim=1)
            viz_window = viz.line(
                X=x_axis,
                Y=y_axis,
                opts=opts,
            )
    else:
        start_epoch = args.start_epoch
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))
        model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                           nb_layers=args.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[args.rnn_type],
                           audio_conf=audio_conf,
                           bidirectional=not args.look_ahead)
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.nGPU))).cuda()
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=True)

    # define loss function (criterion) and decoder
    # best_cer always starts as None, including after a resume, so the first
    # epoch of any run is recorded as the best one.
    best_cer = None
    criterion = CTCLoss()
    decoder = GreedyDecoder(labels)

    # define dataloader
    if not args.evaluate:
        train_dataset = SpectrogramDataset(
            audio_conf=audio_conf,
            manifest_filepath=args.train_manifest,
            labels=labels,
            normalize=True,
            augment=args.augment)
        train_sampler = BucketingSampler(train_dataset,
                                         batch_size=args.batch_size)
        train_loader = AudioDataLoader(train_dataset,
                                       num_workers=args.num_workers,
                                       batch_sampler=train_sampler)
        if not args.in_order and start_epoch != 0:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle()
    val_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                     manifest_filepath=args.val_manifest,
                                     labels=labels,
                                     normalize=True,
                                     augment=False)
    val_loader = AudioDataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers)

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    if args.evaluate:
        validate(val_loader, model, decoder, 0)
        return

    for epoch in range(start_epoch, args.epochs):
        avg_loss = train(train_loader, train_sampler, model, criterion,
                         optimizer, epoch)
        cer = validate(val_loader, model, decoder, epoch)

        loss_results[epoch] = avg_loss
        cer_results[epoch] = cer

        adjust_learning_rate(optimizer)

        is_best = False
        if best_cer is None or best_cer > cer:
            print('Found better validated model')
            best_cer = cer
            is_best = True
        save_checkpoint(
            DeepSpeech.serialize(model,
                                 optimizer=optimizer,
                                 epoch=epoch,
                                 loss_results=loss_results,
                                 cer_results=cer_results), is_best, epoch)

        if not args.in_order:
            print("Shuffling batches...")
            train_sampler.shuffle()

        if args.visdom:
            x_axis = epochs[0:epoch + 1]
            y_axis = torch.stack(
                (loss_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1)
            if viz_window is None:
                # First plot of this run: create the window and keep its
                # handle for later updates.
                viz_window = viz.line(
                    X=x_axis,
                    Y=y_axis,
                    opts=opts,
                )
            else:
                viz.line(
                    X=x_axis.unsqueeze(0).expand(y_axis.size(1),
                                                 x_axis.size(0)).transpose(
                                                     0, 1),  # Visdom fix
                    Y=y_axis,
                    win=viz_window,
                    update='replace',
                )
Example #3
0
def get_data_loader(manifest_file_path, labels, sample_rate, window_size,
                    window_stride, batch_size):
    """Return a ``DataLoader`` that draws bucketed batches from a manifest.

    A ``SpectrogramDataset`` is built from the label set, the three audio
    settings and the manifest path (passed positionally, in that order). A
    ``BucketingSampler`` over that dataset groups it into batches of
    ``batch_size``, and the sampler is handed to the loader as its
    ``batch_sampler``.

    Args:
        manifest_file_path: Manifest forwarded to ``SpectrogramDataset``.
        labels: Label set forwarded to ``SpectrogramDataset``.
        sample_rate: Forwarded to ``SpectrogramDataset``.
        window_size: Forwarded to ``SpectrogramDataset``.
        window_stride: Forwarded to ``SpectrogramDataset``.
        batch_size: Batch size given to ``BucketingSampler``.

    Returns:
        The ``DataLoader`` wrapping the dataset and its bucketing sampler.
    """
    # Positional order expected by the dataset constructor; the manifest
    # path comes last.
    dataset_args = (labels, sample_rate, window_size, window_stride,
                    manifest_file_path)
    spectrograms = SpectrogramDataset(*dataset_args)
    bucketed_batches = BucketingSampler(spectrograms, batch_size=batch_size)
    loader = DataLoader(spectrograms, batch_sampler=bucketed_batches)
    return loader