Example #1
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau

# Project-local imports assumed to be defined elsewhere in this repo:
# SpeechNet, BeamCTCDecoder, PHONEME_MAP, ParamScheduler, scale_cos (and weights_init).


class create_model(nn.Module):
    def __init__(self, args):
        super(create_model, self).__init__()
        self.args = args
        self.model = SpeechNet(args)
        self.model.to(args.device)
        self.criterion = nn.CTCLoss()
        self.decoder = BeamCTCDecoder(PHONEME_MAP, blank_index=0, beam_width=args.beam_width)

        self.state_names = ['loss', 'edit_dist', 'lr']

    def train_setup(self):
        self.lr = self.args.lr
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        if self.args.use_step_schedule:
            self.scheduler = MultiStepLR(self.optimizer, milestones=self.args.decay_steps, gamma=self.args.lr_gamma)
        elif self.args.use_reduce_schedule:
            self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=1)
        else:
            self.scheduler = ParamScheduler(self.optimizer, scale_cos, self.args.num_epochs * self.args.loader_length)
#         self.model.apply(weights_init)
        self.model.train()

    def optimize_parameters(self, input, input_lens, target, target_lens):
        input, target = input.to(self.args.device), target.to(self.args.device)
        output, output_lens, self.loss = self.forward(input, input_lens, target, target_lens)

        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

        self.edit_dist = self.get_edit_dist(output, output_lens, target, target_lens)

        # drop references to the batch tensors so their (GPU) memory can be freed promptly
        del input, target, input_lens, target_lens, output, output_lens

    def update_learning_rate(self, dist=None):
        if self.args.use_reduce_schedule:
            self.scheduler.step(dist)
        else:
            self.scheduler.step()
        self.lr = self.optimizer.param_groups[0]['lr']

    def get_current_states(self):
        errors_ret = OrderedDict()
        for name in self.state_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, name))
        return errors_ret

    def get_edit_dist(self, output, output_lens, target, target_lens):
        output, target = output.cpu(), target.cpu()
        phoneme_preds = self.decoder.decode(output, output_lens)
        phonemes = self.decoder.convert_to_strings(target, target_lens)
        edit_dist = np.sum(
            [self.decoder.Lev_dist(phoneme_pred, phoneme) for (phoneme_pred, phoneme) in zip(phoneme_preds, phonemes)])
        return edit_dist

    def forward(self, input, input_lens, target=None, target_lens=None, is_training=True):
        output, output_lens = self.model(input, input_lens)
        if is_training:
            # The official documentation is your best friend: https://pytorch.org/docs/stable/nn.html#ctcloss
            # nn.CTCLoss takes 4 arguments to compute the loss:
            # [log_probs]: Prediction of your model at each time step. Shape: (seq_len, batch_size, vocab_size)
            # Values must be log probabilities. Neither probabilities nor logits will work.
            # Make sure the output of your network is log probabilities, by adding a nn.LogSoftmax after the last layer.
            # [targets]: The ground truth sequences. Shape: (batch_size, seq_len)
            # Values are indices of phonemes. Again, remember that index 0 is reserved for "blank"
            # [input_lengths]: Lengths of sequences in log_probs. Shape: (batch_size,).
            # This is not necessarily the same as lengths of input of the model.
            # [target_lengths]: Lengths of sequences in targets. Shape: (batch_size,).
            # Pass output_lens (the lengths of the sequences in log_probs), not input_lens,
            # since the model may shorten the time dimension.
            loss = self.criterion(output.permute(1, 0, 2), target, output_lens, target_lens)
            return output, output_lens, loss
        else:
            return output, output_lens

    def train(self):
        try:
            self.model.train()
        except AttributeError:
            print('train() cannot be called because the model does not exist.')

    def eval(self):
        try:
            self.model.eval()
        except AttributeError:
            print('eval() cannot be called because the model does not exist.')

    def load_model(self, model_path):
        self.model.load_state_dict(torch.load(model_path))

    def save_model(self, which_epoch):
        save_filename = '%s_net.pth' % which_epoch
        save_path = os.path.join(self.args.expr_dir, save_filename)
        if torch.cuda.is_available():
            try:
                # if the model is wrapped in DataParallel, save the underlying module
                torch.save(self.model.module.cpu().state_dict(), save_path)
            except AttributeError:
                torch.save(self.model.cpu().state_dict(), save_path)
        else:
            torch.save(self.model.cpu().state_dict(), save_path)

        self.model.to(self.args.device)
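
To make the tensor shapes documented in forward() concrete, the following is a minimal, self-contained sketch of an nn.CTCLoss call on dummy data. The sizes (seq_len, batch_size, vocab_size, target_len) are illustrative values invented for this sketch, not values taken from this project.

import torch
import torch.nn as nn

seq_len, batch_size, vocab_size = 50, 4, 42   # index 0 of the vocabulary is reserved for "blank"
target_len = 20

criterion = nn.CTCLoss(blank=0)

# [log_probs]: (seq_len, batch_size, vocab_size); values must be log probabilities,
# e.g. produced by a final LogSoftmax over the vocabulary dimension.
log_probs = torch.randn(seq_len, batch_size, vocab_size).log_softmax(dim=-1)

# [targets]: (batch_size, max_target_len); phoneme indices, never the blank index 0.
targets = torch.randint(1, vocab_size, (batch_size, target_len), dtype=torch.long)

# [input_lengths]: length of each sequence in log_probs, shape (batch_size,).
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)

# [target_lengths]: length of each sequence in targets, shape (batch_size,).
target_lengths = torch.full((batch_size,), target_len, dtype=torch.long)

loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())

Note that log_probs here is already time-major, which is why forward() above permutes the batch-major model output with output.permute(1, 0, 2) before calling the criterion.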
Example #2
        # unflatten the concatenated targets back into one sequence per utterance
        split_targets = []
        offset = 0
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size

        if args.cuda:
            inputs = inputs.cuda()

        out = model(inputs)
        out = out.transpose(0, 1)  # TxNxH
        seq_length = out.size(0)
        sizes = input_percentages.mul_(int(seq_length)).int()

        decoded_output = decoder.decode(out.data, sizes)
        target_strings = decoder.process_strings(
            decoder.convert_to_strings(split_targets))
        wer, cer = 0, 0
        for x in range(len(target_strings)):
            wer += decoder.wer(decoded_output[x], target_strings[x]) / float(
                len(target_strings[x].split()))
            cer += decoder.cer(decoded_output[x], target_strings[x]) / float(
                len(target_strings[x]))
        total_cer += cer
        total_wer += wer

    wer = total_wer / len(test_loader.dataset)
    cer = total_cer / len(test_loader.dataset)

    print('Test Summary \t'
          'Average WER {wer:.3f}\t'
          'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))
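
The per-utterance normalization above (word errors divided by the number of reference words, character errors divided by the number of reference characters, then averaged over the dataset) is the standard WER/CER definition. As an illustration only, a minimal stand-alone version of that computation could look like the sketch below; levenshtein, wer and cer here are hypothetical helpers, not the decoder's actual API.

def levenshtein(a, b):
    """Edit distance between two sequences (strings or lists of words)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (x != y)))  # substitution
        prev = curr
    return prev[-1]


def wer(hypothesis, reference):
    """Word error rate of one utterance, normalized by the reference word count."""
    ref_words = reference.split()
    return levenshtein(hypothesis.split(), ref_words) / float(len(ref_words))


def cer(hypothesis, reference):
    """Character error rate of one utterance, normalized by the reference length."""
    return levenshtein(hypothesis, reference) / float(len(reference))


# Made-up (hypothesis, reference) pairs standing in for decoded_output / target_strings.
pairs = [('the cat sat', 'the cat sat on'), ('helo world', 'hello world')]
avg_wer = sum(wer(h, r) for h, r in pairs) / len(pairs)
avg_cer = sum(cer(h, r) for h, r in pairs) / len(pairs)
print('Average WER {:.3f}\tAverage CER {:.3f}'.format(avg_wer * 100, avg_cer * 100))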