import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau

# SpeechNet, BeamCTCDecoder, PHONEME_MAP, ParamScheduler and scale_cos are
# project-local and imported from elsewhere in the repo.


class create_model(nn.Module):
    def __init__(self, args):
        super(create_model, self).__init__()
        self.args = args
        self.model = SpeechNet(args)
        self.model.to(args.device)
        self.criterion = nn.CTCLoss()
        self.decoder = BeamCTCDecoder(PHONEME_MAP, blank_index=0,
                                      beam_width=args.beam_width)
        self.state_names = ['loss', 'edit_dist', 'lr']

    def train_setup(self):
        self.lr = self.args.lr
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
                                          weight_decay=self.args.weight_decay)
        if self.args.use_step_schedule:
            self.scheduler = MultiStepLR(self.optimizer, milestones=self.args.decay_steps,
                                         gamma=self.args.lr_gamma)
        elif self.args.use_reduce_schedule:
            self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min',
                                               factor=0.5, patience=1)
        else:
            self.scheduler = ParamScheduler(self.optimizer, scale_cos,
                                            self.args.num_epochs * self.args.loader_length)
        # self.model.apply(weights_init)
        self.model.train()

    def optimize_parameters(self, input, input_lens, target, target_lens):
        input, target = input.to(self.args.device), target.to(self.args.device)
        output, output_lens, self.loss = self.forward(input, input_lens, target, target_lens)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
        self.edit_dist = self.get_edit_dist(output, output_lens, target, target_lens)
        del input, target, input_lens, target_lens, output, output_lens

    def update_learning_rate(self, dist=None):
        if self.args.use_reduce_schedule:
            self.scheduler.step(dist)
        else:
            self.scheduler.step()
        self.lr = self.optimizer.param_groups[0]['lr']

    def get_current_states(self):
        errors_ret = OrderedDict()
        for name in self.state_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensors and plain floats
                errors_ret[name] = float(getattr(self, name))
        return errors_ret

    def get_edit_dist(self, output, output_lens, target, target_lens):
        output, target = output.cpu(), target.cpu()
        phoneme_preds = self.decoder.decode(output, output_lens)
        phonemes = self.decoder.convert_to_strings(target, target_lens)
        edit_dist = np.sum([self.decoder.Lev_dist(phoneme_pred, phoneme)
                            for (phoneme_pred, phoneme) in zip(phoneme_preds, phonemes)])
        return edit_dist

    def forward(self, input, input_lens, target=None, target_lens=None, is_training=True):
        output, output_lens = self.model(input, input_lens)
        if is_training:
            # The official documentation is your best friend:
            # https://pytorch.org/docs/stable/nn.html#ctcloss
            # nn.CTCLoss takes 4 arguments to compute the loss:
            # [log_probs]: predictions of your model at each time step.
            #     Shape: (seq_len, batch_size, vocab_size). Values must be log
            #     probabilities; neither probabilities nor logits will work, so
            #     make sure the network ends with nn.LogSoftmax.
            # [targets]: the ground-truth sequences. Shape: (batch_size, seq_len).
            #     Values are phoneme indices; remember that index 0 is reserved for "blank".
            # [input_lengths]: lengths of the sequences in log_probs. Shape: (batch_size,).
            #     Not necessarily the same as the lengths of the model's inputs.
            # [target_lengths]: lengths of the sequences in targets. Shape: (batch_size,).
            # output_lens (not input_lens) matches the time dimension of the
            # network output, which is what CTC's input_lengths must describe.
            loss = self.criterion(output.permute(1, 0, 2), target, output_lens, target_lens)
            return output, output_lens, loss
        else:
            return output, output_lens

    def train(self):
        try:
            self.model.train()
        except AttributeError:
            print('train() cannot be called because the model does not exist.')

    def eval(self):
        try:
            self.model.eval()
        except AttributeError:
            print('eval() cannot be called because the model does not exist.')

    def load_model(self, model_path):
        self.model.load_state_dict(torch.load(model_path))

    def save_model(self, which_epoch):
        save_filename = '%s_net.pth' % which_epoch
        save_path = os.path.join(self.args.expr_dir, save_filename)
        if torch.cuda.is_available():
            try:
                # nn.DataParallel wraps the network in .module
                torch.save(self.model.module.cpu().state_dict(), save_path)
            except AttributeError:
                torch.save(self.model.cpu().state_dict(), save_path)
        else:
            torch.save(self.model.cpu().state_dict(), save_path)
        self.model.to(self.args.device)
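# ---------------------------------------------------------------------------
# Minimal shape sanity check for the nn.CTCLoss call documented in forward()
# above. Purely illustrative: T, N, C and the random tensors are made-up
# values, not project data; only the argument shapes and the log-probability
# requirement mirror the real call.
if __name__ == '__main__':
    T, N, C = 50, 4, 42  # seq_len, batch_size, vocab_size (index 0 = blank)
    log_probs = torch.randn(T, N, C).log_softmax(2)            # CTC wants log probs
    targets = torch.randint(1, C, (N, 30), dtype=torch.long)   # 0 is reserved for blank
    input_lengths = torch.full((N,), T, dtype=torch.long)      # lengths of log_probs
    target_lengths = torch.randint(10, 31, (N,), dtype=torch.long)
    print(nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths).item())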
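# ---------------------------------------------------------------------------
# Illustrative driver (an assumption, not shipped code): one way the class API
# above fits together for training. `args` and `train_loader` are hypothetical
# stand-ins for the project's real argument object and data loader.
def run_training(args, train_loader):
    model = create_model(args)
    model.train_setup()
    for epoch in range(args.num_epochs):
        epoch_dist = 0.0
        for input, input_lens, target, target_lens in train_loader:
            model.optimize_parameters(input, input_lens, target, target_lens)
            epoch_dist += model.edit_dist
            if not args.use_reduce_schedule:
                # the cosine ParamScheduler above is sized for per-iteration steps
                model.update_learning_rate()
        if args.use_reduce_schedule:
            # ReduceLROnPlateau needs a metric; feed it the epoch edit distance
            model.update_learning_rate(dist=epoch_dist)
        print(model.get_current_states())
        model.save_model(epoch)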
# ---------------------------------------------------------------------------
# Standalone evaluation snippet (separate from the class above). The loop
# scaffolding below is reconstructed; the batch unpacking assumes a
# deepspeech.pytorch-style collate function that yields
# (inputs, targets, input_percentages, target_sizes).
model.eval()
total_cer, total_wer = 0, 0
for data in test_loader:
    inputs, targets, input_percentages, target_sizes = data

    # unflatten targets
    split_targets = []
    offset = 0
    for size in target_sizes:
        split_targets.append(targets[offset:offset + size])
        offset += size

    if args.cuda:
        inputs = inputs.cuda()

    out = model(inputs)
    out = out.transpose(0, 1)  # TxNxH
    seq_length = out.size(0)
    # recover per-utterance output lengths from the padded batch
    sizes = input_percentages.mul_(int(seq_length)).int()

    decoded_output = decoder.decode(out.data, sizes)
    target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets))
    wer, cer = 0, 0
    for x in range(len(target_strings)):
        # normalize edit distance by reference length: words for WER, chars for CER
        wer += decoder.wer(decoded_output[x], target_strings[x]) / float(len(target_strings[x].split()))
        cer += decoder.cer(decoded_output[x], target_strings[x]) / float(len(target_strings[x]))
    total_cer += cer
    total_wer += wer

wer = total_wer / len(test_loader.dataset)
cer = total_cer / len(test_loader.dataset)

print('Test Summary \t'
      'Average WER {wer:.3f}\t'
      'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))
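# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not the decoder's actual code) of what
# decoder.wer / decoder.cer compute: a Levenshtein edit distance over word
# tokens for WER and over characters for CER, which the loop above then
# normalizes by the reference length.
def _levenshtein(a, b):
    # classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]

# e.g. _levenshtein('kitten', 'sitting') == 3 character edits (CER numerator);
#      _levenshtein('the cat sat'.split(), 'the cat sit'.split()) == 1 word edit.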