def build_ctc_model(T, B):
    """Build a CTCModel with fixed hyperparameters (T and B are unused)."""
    charmap = data.CharMap()

    return models.CTCModel(charmap,
                           n_mels=80,
                           nhidden_rnn=185,
                           nlayers_rnn=3,
                           cell_type='GRU',
                           dropout=0.1)
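
# A minimal usage sketch (not in the original source): build the model
# and count trainable parameters. Assumes the project-local 'data' and
# 'models' modules are importable; T and B are unused by the builder.
model = build_ctc_model(T=1000, B=32)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"CTC model with {n_params:,} trainable parameters")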

# Example 2

def test(args):
    """
    Test function to decode a sample with a pretrained model
    """
    import matplotlib.pyplot as plt

    logger = logging.getLogger(__name__)
    logger.info("Test")

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # We need the char map to know about the vocabulary size
    charmap = data.CharMap()

    # Create the model
    # The same architecture as the one used during training must be
    # rebuilt here. If you do not remember the parameters, check the
    # summary.txt file in the logdir where your modelpath .pt file was
    # saved. A cleaner approach is to store the training hyperparameters
    # in a YAML file and load that file here (see the sketch after this
    # function).
    n_mels = args.nmels
    nhidden_rnn = args.nhidden_rnn
    nlayers_rnn = args.nlayers_rnn
    cell_type = args.cell_type
    dropout = args.dropout

    modelpath = args.modelpath
    audiofile = args.audiofile
    beamwidth = args.beamwidth
    beamsearch = args.beamsearch
    assert (modelpath is not None)
    assert (audiofile is not None)

    logger.info("Building the model")
    model = models.CTCModel(charmap, n_mels, nhidden_rnn, nlayers_rnn,
                            cell_type, dropout)
    model.to(device)
    model.load_state_dict(torch.load(modelpath, map_location=device))

    # Switch the model to eval mode
    model.eval()

    # Load and preprocess the audiofile
    logger.info("Loading and preprocessing the audio file")
    waveform, sample_rate = torchaudio.load(audiofile)
    waveform = torchaudio.transforms.Resample(
        sample_rate, data._DEFAULT_RATE)(waveform).transpose(0, 1)  # (T, B=num_channels)
    # Hardcoded normalization, this is dirty, I agree
    spectro_normalization = (-31, 32)
    # The processor for computing the spectrogram
    waveform_processor = data.WaveformProcessor(
        data._DEFAULT_RATE, data._DEFAULT_WIN_LENGTH * 1e-3,
        data._DEFAULT_WIN_STEP * 1e-3, n_mels, False, spectro_normalization)
    spectrogram = waveform_processor(waveform).to(device)
    spectro_length = spectrogram.shape[0]

    # Plot the spectrogram
    logger.info("Plotting the spectrogram")
    fig = plt.figure()
    ax = fig.add_subplot()
    # spectrogram is (T, B=1, n_mels); plot the single batch item
    ax.imshow(spectrogram[:, 0, :].cpu().numpy(),
              aspect='auto',
              cmap='magma',
              origin='lower')
    ax.set_xlabel("Mel scale")
    ax.set_ylabel("Time (frame)")
    fig.tight_layout()
    plt.savefig("spectro_test.png")

    spectrogram = pack_padded_sequence(spectrogram, lengths=[spectro_length])

    logger.info("Decoding the spectrogram")

    if beamsearch:
        likely_sequences = model.beam_decode(spectrogram, beamwidth,
                                             charmap.blankid)
    else:
        likely_sequences = model.decode(spectrogram)

    print("Log prob    Sequence\n")
    print("\n".join(
        ["{:.2f}      {}".format(p, s) for (p, s) in likely_sequences]))

# Example 3

def train(args):
    """
    Training of the algorithm
    """
    logger = logging.getLogger(__name__)
    logger.info("Training")

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # Data loading
    loaders = data.get_dataloaders(args.datasetroot,
                                   args.datasetversion,
                                   cuda=use_cuda,
                                   batch_size=args.batch_size,
                                   n_threads=args.nthreads,
                                   min_duration=args.min_duration,
                                   max_duration=args.max_duration,
                                   small_experiment=args.debug,
                                   train_augment=args.train_augment,
                                   nmels=args.nmels,
                                   logger=logger)
    train_loader, valid_loader, test_loader = loaders

    # Parameters
    n_mels = args.nmels
    nhidden_rnn = args.nhidden_rnn
    nlayers_rnn = args.nlayers_rnn
    cell_type = args.cell_type
    dropout = args.dropout
    base_lr = args.base_lr
    num_epochs = args.num_epochs
    grad_clip = args.grad_clip

    # We need the char map to know about the vocabulary size
    charmap = data.CharMap()
    blank_id = charmap.blankid

    # Model definition
    model = models.CTCModel(charmap, n_mels, nhidden_rnn, nlayers_rnn,
                            cell_type, dropout)

    decode = model.decode

    model.to(device)

    # Loss, optimizer
    baseloss = nn.CTCLoss(blank=blank_id, reduction='mean', zero_infinity=True)
    def loss(*params):
        return baseloss(*wrap_ctc_args(*params))

    optimizer = optim.Adam(model.parameters(), lr=base_lr)

    metrics = {'CTC': loss}

    # Callbacks
    summary_text = (
        "## Summary of the model architecture\n"
        f"{deepcs.display.torch_summarize(model)}\n"
        "\n\n## Executed command:\n"
        f"{' '.join(sys.argv)}"
        f"\n\n## Args:\n {args}"
    )

    logger.info(summary_text)

    logdir = generate_unique_logpath('./logs', 'ctc')
    tensorboard_writer = SummaryWriter(log_dir=logdir, flush_secs=5)
    tensorboard_writer.add_text("Experiment summary",
                                deepcs.display.htmlize(summary_text))

    with open(os.path.join(logdir, "summary.txt"), 'w') as f:
        f.write(summary_text)

    logger.info(f">>>>> Results saved in {logdir}")

    model_checkpoint = ModelCheckpoint(model,
                                       os.path.join(logdir, 'best_model.pt'))
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Training loop
    for e in range(num_epochs):
        ftrain(model,
               train_loader,
               loss,
               optimizer,
               device,
               metrics,
               grad_clip=grad_clip,
               num_model_args=1,
               num_epoch=e,
               tensorboard_writer=tensorboard_writer)

        # Compute and record the metrics on the validation set
        valid_metrics = ftest(model,
                              valid_loader,
                              device,
                              metrics,
                              num_model_args=1)
        better_model = model_checkpoint.update(valid_metrics['CTC'])
        scheduler.step()

        logger.info("[%d/%d] Validation:   CTCLoss : %.3f %s" %
                    (e, num_epochs, valid_metrics['CTC'],
                     "[>> BETTER <<]" if better_model else ""))

        for m_name, m_value in valid_metrics.items():
            tensorboard_writer.add_scalar(f'metrics/valid_{m_name}', m_value,
                                          e + 1)
        # Compute and record the metrics on the test set
        test_metrics = ftest(model,
                             test_loader,
                             device,
                             metrics,
                             num_model_args=1)
        logger.info("[%d/%d] Test:   Loss : %.3f " %
                    (e, num_epochs, test_metrics['CTC']))
        for m_name, m_value in test_metrics.items():
            tensorboard_writer.add_scalar(f'metrics/test_{m_name}', m_value,
                                          e + 1)
        # Try to decode some of the validation samples (a greedy CTC
        # decoding sketch is given at the end of this example)
        model.eval()
        valid_decodings = decode_samples(decode,
                                         valid_loader,
                                         n=2,
                                         device=device,
                                         charmap=charmap)
        train_decodings = decode_samples(decode,
                                         train_loader,
                                         n=2,
                                         device=device,
                                         charmap=charmap)

        decoding_results = "## Decoding results on the training set\n"
        decoding_results += train_decodings
        decoding_results += "## Decoding results on the validation set\n"
        decoding_results += valid_decodings
        tensorboard_writer.add_text("Decodings",
                                    deepcs.display.htmlize(decoding_results),
                                    global_step=e + 1)
        logger.info("\n" + decoding_results)