def build_ctc_model(T, B,
                    n_mels=80,
                    nhidden_rnn=185,
                    nlayers_rnn=3,
                    cell_type='GRU',
                    dropout=0.1):
    """Build a CTC speech-recognition model over the default character map.

    Args:
        T, B: unused here; kept for backward compatibility with existing
            callers (presumably max time steps / batch size — TODO confirm).
        n_mels (int): number of mel filterbank channels of the input.
        nhidden_rnn (int): hidden size of each recurrent layer.
        nlayers_rnn (int): number of stacked recurrent layers.
        cell_type (str): recurrent cell type, e.g. 'GRU'.
        dropout (float): dropout probability inside the model.

    Returns:
        models.CTCModel: a freshly constructed model (not moved to a device).
    """
    # The char map defines the output vocabulary (including the blank token).
    charmap = data.CharMap()
    return models.CTCModel(charmap,
                           n_mels=n_mels,
                           nhidden_rnn=nhidden_rnn,
                           nlayers_rnn=nlayers_rnn,
                           cell_type=cell_type,
                           dropout=dropout)
def train(args):
    """Training of the algorithm.

    Builds the dataloaders, the CTC model, the loss and optimizer, then runs
    the training loop: one optimization pass per epoch, validation/test
    metric logging to tensorboard, best-model checkpointing and a few sample
    decodings per epoch.

    Args:
        args: parsed command-line namespace; must provide the dataset,
            model-architecture and optimization hyperparameters read below.
    """
    logger = logging.getLogger(__name__)
    logger.info("Training")

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # Data loading
    loaders = data.get_dataloaders(args.datasetroot, args.datasetversion,
                                   cuda=use_cuda,
                                   batch_size=args.batch_size,
                                   n_threads=args.nthreads,
                                   min_duration=args.min_duration,
                                   max_duration=args.max_duration,
                                   small_experiment=args.debug,
                                   train_augment=args.train_augment,
                                   nmels=args.nmels,
                                   logger=logger)
    train_loader, valid_loader, test_loader = loaders

    # Parameters
    n_mels = args.nmels
    nhidden_rnn = args.nhidden_rnn
    nlayers_rnn = args.nlayers_rnn
    cell_type = args.cell_type
    dropout = args.dropout
    base_lr = args.base_lr
    num_epochs = args.num_epochs
    grad_clip = args.grad_clip

    # We need the char map to know about the vocabulary size
    charmap = data.CharMap()
    blank_id = charmap.blankid

    # Model definition.
    # NOTE(review): this was a "START CODING HERE" placeholder leaving
    # model = None, which crashed on model.decode below. Constructor call
    # mirrors the one used in test().
    model = models.CTCModel(charmap,
                            n_mels=n_mels,
                            nhidden_rnn=nhidden_rnn,
                            nlayers_rnn=nlayers_rnn,
                            cell_type=cell_type,
                            dropout=dropout)
    decode = model.decode
    model.to(device)

    # Loss, optimizer
    # zero_infinity avoids NaN gradients when a target is longer than the
    # available input frames.
    baseloss = nn.CTCLoss(blank=blank_id, reduction='mean',
                          zero_infinity=True)
    loss = lambda *params: baseloss(*wrap_ctc_args(*params))

    # NOTE(review): placeholder filled in with plain Adam; the StepLR
    # scheduler below handles the learning-rate decay.
    optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

    metrics = {'CTC': loss}

    # Callbacks
    summary_text = "## Summary of the model architecture\n" + \
        f"{deepcs.display.torch_summarize(model)}\n"
    summary_text += "\n\n## Executed command :\n" +\
        "{}".format(" ".join(sys.argv))
    summary_text += "\n\n## Args : \n {}".format(args)
    logger.info(summary_text)

    logdir = generate_unique_logpath('./logs', 'ctc')
    tensorboard_writer = SummaryWriter(log_dir=logdir, flush_secs=5)
    tensorboard_writer.add_text("Experiment summary",
                                deepcs.display.htmlize(summary_text))
    with open(os.path.join(logdir, "summary.txt"), 'w') as f:
        f.write(summary_text)

    logger.info(f">>>>> Results saved in {logdir}")

    model_checkpoint = ModelCheckpoint(model,
                                       os.path.join(logdir, 'best_model.pt'))
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Training loop
    for e in range(num_epochs):
        ftrain(model, train_loader, loss, optimizer, device, metrics,
               grad_clip=grad_clip,
               num_model_args=1,
               num_epoch=e,
               tensorboard_writer=tensorboard_writer)

        # Compute and record the metrics on the validation set
        valid_metrics = ftest(model, valid_loader, device, metrics,
                              num_model_args=1)
        better_model = model_checkpoint.update(valid_metrics['CTC'])
        scheduler.step()
        logger.info("[%d/%d] Validation: CTCLoss : %.3f %s" %
                    (e, num_epochs, valid_metrics['CTC'],
                     "[>> BETTER <<]" if better_model else ""))
        for m_name, m_value in valid_metrics.items():
            tensorboard_writer.add_scalar(f'metrics/valid_{m_name}',
                                          m_value, e + 1)

        # Compute and record the metrics on the test set
        test_metrics = ftest(model, test_loader, device, metrics,
                             num_model_args=1)
        logger.info("[%d/%d] Test: Loss : %.3f " %
                    (e, num_epochs, test_metrics['CTC']))
        for m_name, m_value in test_metrics.items():
            tensorboard_writer.add_scalar(f'metrics/test_{m_name}',
                                          m_value, e + 1)

        # Try to decode some of the validation samples
        model.eval()
        valid_decodings = decode_samples(decode, valid_loader, n=2,
                                         device=device, charmap=charmap)
        train_decodings = decode_samples(decode, train_loader, n=2,
                                         device=device, charmap=charmap)

        decoding_results = "## Decoding results on the training set\n"
        decoding_results += train_decodings
        decoding_results += "## Decoding results on the validation set\n"
        decoding_results += valid_decodings
        tensorboard_writer.add_text("Decodings",
                                    deepcs.display.htmlize(decoding_results),
                                    global_step=e + 1)
        logger.info("\n" + decoding_results)
def test(args):
    """ Test function to decode a sample with a pretrained model """
    import matplotlib.pyplot as plt

    logger = logging.getLogger(__name__)
    logger.info("Test")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The char map gives us the vocabulary size of the model outputs.
    charmap = data.CharMap()

    # The architecture hyperparameters must match the ones used when the
    # checkpoint was trained; look them up in the summary.txt saved next to
    # the modelpath .pt file. Storing them in a yaml alongside the weights
    # and reloading that here would be a cleaner scheme.
    modelpath = args.modelpath
    audiofile = args.audiofile
    assert modelpath is not None
    assert audiofile is not None

    # Rebuild the network and load the pretrained weights.
    logger.info("Building the model")
    model = models.CTCModel(charmap,
                            args.nmels,
                            args.nhidden_rnn,
                            args.nlayers_rnn,
                            args.cell_type,
                            args.dropout)
    model.to(device)
    model.load_state_dict(torch.load(modelpath))
    # Switch the model to eval mode
    model.eval()

    # Load the audio, resample it to the rate the model expects, and
    # reshape to (T, B).
    logger.info("Loading and preprocessing the audio file")
    waveform, sample_rate = torchaudio.load(audiofile)
    resample = torchaudio.transforms.Resample(sample_rate, data._DEFAULT_RATE)
    waveform = resample(waveform).transpose(0, 1)  # (T, B)

    # Hardcoded normalization, this is dirty, I agree
    spectro_normalization = (-31, 32)

    # The processor for computing the spectrogram
    waveform_processor = data.WaveformProcessor(data._DEFAULT_RATE,
                                                data._DEFAULT_WIN_LENGTH * 1e-3,
                                                data._DEFAULT_WIN_STEP * 1e-3,
                                                args.nmels,
                                                False,
                                                spectro_normalization)
    spectrogram = waveform_processor(waveform).to(device)
    spectro_length = spectrogram.shape[0]

    # Plot the spectrogram
    # NOTE(review): the x/y labels below may be swapped relative to the
    # imshow orientation — confirm against an actual plot.
    logger.info("Plotting the spectrogram")
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.imshow(spectrogram[0].cpu().numpy(),
              aspect='equal', cmap='magma', origin='lower')
    ax.set_xlabel("Mel scale")
    ax.set_ylabel("Time (sample)")
    fig.tight_layout()
    plt.savefig("spectro_test.png")

    # The model consumes packed sequences; this one is a batch of one.
    spectrogram = pack_padded_sequence(spectrogram, lengths=[spectro_length])

    logger.info("Decoding the spectrogram")
    if args.beamsearch:
        likely_sequences = model.beam_decode(spectrogram,
                                             args.beamwidth,
                                             charmap.blankid)
    else:
        likely_sequences = model.decode(spectrogram)

    print("Log prob Sequence\n")
    print("\n".join(
        ["{:.2f} {}".format(p, s) for (p, s) in likely_sequences]))