Example 1
    def instantiate_network(self, args):
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))

        rnn_type = args.rnn_type.lower()
        assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
        self.net = DeepSpeech(rnn_hidden_size=args.hidden_size,
                              nb_layers=args.hidden_layers,
                              labels=labels,
                              rnn_type=supported_rnns[rnn_type],
                              audio_conf=audio_conf,
                              bidirectional=args.bidirectional)
        self.init_optimizer(args)
        self.net.decoder = GreedyDecoder(labels)
        self.net.audio_conf = audio_conf
        self.net.labels = labels
        self.net.device = self.device
        self.net = self.net.to(self.device)
        return self.net
Example 2
def main(args):
    alphabet = alphabet_factory()
    device = torch.device('cpu')
    checkpoint = torch.load('model_best.pth', map_location=device)
    in_features = args.n_mfcc * (2 * args.n_context + 1)
    model = build_deepspeech(in_features=in_features,
                             num_classes=len(alphabet))
    model.load_state_dict(checkpoint['state_dict'])
    print_size_of_model(model)
    decoder = GreedyDecoder()
    if args.quantize:
        model = torch.quantization.quantize_dynamic(model, {nn.RNN, nn.Linear},
                                                    dtype=torch.qint8)
        logging.info('quantized model')
        print_size_of_model(model)

    transform = prepare_transformations(args)
    dataset = ProcessedDataset(get_dataset(args.datadir, "dev-clean"),
                               transform, alphabet)
    collate_fn = collate_factory(model_length_function)
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             collate_fn=collate_fn,
                                             drop_last=False)
    test_loop_fn(dataloader, model, criterion, device, 1, decoder, alphabet)
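This evaluation entry point only needs a handful of attributes on args. A minimal, hypothetical launcher is sketched below; the flag names and defaults are assumptions, and prepare_transformations may read additional fields that are not listed here.

# Hypothetical launcher for the evaluation entry point above.
# Only the attributes read directly by main() are declared; defaults are assumptions.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a DeepSpeech checkpoint")
    parser.add_argument("--datadir", required=True, help="root of the dataset containing dev-clean")
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=8)
    parser.add_argument("--n-mfcc", dest="n_mfcc", type=int, default=26)
    parser.add_argument("--n-context", dest="n_context", type=int, default=9)
    parser.add_argument("--quantize", action="store_true",
                        help="apply dynamic int8 quantization before evaluation")
    main(parser.parse_args())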
Example 3
def main(index, args):
    alphabet = alphabet_factory()
    train_dataset, test_dataset = split_dataset(args, alphabet)
    collate_fn = collate_factory(model_length_function)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        # pin_memory=True,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=args.num_workers,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              collate_fn=collate_fn,
                                              drop_last=True)

    # Get loss function, optimizer, and model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_features = args.n_mfcc * (2 * args.n_context + 1)
    model = build_deepspeech(in_features=in_features,
                             num_classes=len(alphabet))
    model = model.to(device)
    logging.info("Number of parameters: %s", count_parameters(model))

    optimizer = get_optimizer(args, model.parameters())
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    decoder = GreedyDecoder()
    train_eval_fn(args.num_epochs, train_loader, test_loader, optimizer, model,
                  criterion, device, decoder, alphabet, args.checkpoint,
                  args.log_steps)
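Here main takes a process index as its first argument, which matches the callback signature of torch.multiprocessing.spawn; a hedged launch sketch follows. parse_args is a hypothetical helper that would have to supply batch_size, num_workers, num_epochs, checkpoint and log_steps, plus whatever split_dataset and get_optimizer read.

# Hypothetical launcher: torch.multiprocessing.spawn calls main(index, args),
# passing the process index as the first positional argument.
import torch.multiprocessing as mp

if __name__ == "__main__":
    args = parse_args()  # hypothetical argument parser for the fields listed above
    # The body of main() never uses index, so a single worker process is enough here.
    mp.spawn(main, args=(args,), nprocs=1, join=True)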
Example 4
    def distribute(self, args):
        labels = self.net.labels
        audio_conf = self.net.audio_conf
        main_proc = self.initialize_process_group(args)
        if main_proc and self._visdom:  # Add previous scores to visdom graph
            self._visdom_logger.load_previous_values(self._start_epoch, self._package)
        if main_proc and self._tensorboard:  # Previous scores to tensorboard logs
            self._tensorboard_logger.load_previous_values(self._start_epoch, self._package)
        self.net = DistributedDataParallel(self.net)

        self.net.decoder = GreedyDecoder(labels)
        self.net.audio_conf = audio_conf
        self.net.labels = labels
        self.net.device = self.device
        self.net = self.net.to(self.device)
        return main_proc
Example 5
    def continue_from(self, args):
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from, map_location=lambda storage, loc: storage)
        self.net = DeepSpeech.load_model_package(package)
        self.init_optimizer(args)
        self._optim_state = package['optim_dict']
        self._amp_state = package['amp']
        self._optimizer.load_state_dict(self._optim_state)
        amp.load_state_dict(self._amp_state)
        self._start_epoch = int(package.get('epoch', 1)) - 1  # Index start at 0 for training
        self._start_iter = package.get('iteration', None)
        if self._start_iter is None:
            self._start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
            self._start_iter = 0
        else:
            self._start_iter += 1
        self._avg_loss = int(package.get('avg_loss', 0))
        loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], \
                                                 package['wer_results']
        for k, (loss, cer, wer) in enumerate(zip(loss_results, cer_results, wer_results)):
            try:
                self._loss_results[k], self._cer_results[k], self._wer_results[k] = loss, cer, wer
            except IndexError:
                break

        if self._start_epoch > 0:
            self._best_cer = min(cer_results[:self._start_epoch])
            self._best_wer = min(wer_results[:self._start_epoch])
            self._cer = cer_results[self._start_epoch - 1]
            self._wer = wer_results[self._start_epoch - 1]
        else:
            self._best_cer = None
            self._best_wer = None
            self._cer = None
            self._wer = None
        self._package = package
        self.net.decoder = GreedyDecoder(self.net.labels)
        self.net.device = self.device
        return self.net
Example 6
def _main_xla(index, args):
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

    alphabet = alphabet_factory()
    train_dataset, test_dataset = split_dataset(args, alphabet)
    collate_fn = collate_factory(model_length_function)
    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               collate_fn=collate_fn,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              collate_fn=collate_fn,
                                              drop_last=True)

    # Scale learning rate to world size
    lr = args.learning_rate * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    in_features = args.n_mfcc * (2 * args.n_context + 1)  # same feature layout as the CPU/GPU entry points
    model = build_deepspeech(in_features=in_features,
                             num_classes=len(alphabet))
    model = model.to(device)
    optimizer = get_optimizer(args, model.parameters())
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    decoder = GreedyDecoder()

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    class XLAProxyOptimizer:
        """
        XLA Proxy optimizer for compatibility with
        torch.Optimizer
        """
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def zero_grad(self):
            self.optimizer.zero_grad()

        def step(self):
            xm.optimizer_step(self.optimizer)

    optimizer = XLAProxyOptimizer(optimizer)

    train_eval_fn(args.num_epochs, train_device_loader, test_device_loader,
                  optimizer, model, criterion, device, decoder, alphabet,
                  args.checkpoint)
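Like the previous entry point, _main_xla(index, args) has the callback signature used by torch_xla's multiprocessing launcher. A hedged sketch of how it might be spawned on a TPU host follows; parse_args and the process count are assumptions.

# Hypothetical TPU launcher: one Python process per XLA device,
# with the process index forwarded as the first argument of _main_xla.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == "__main__":
    args = parse_args()  # hypothetical; must supply batch_size, num_workers, learning_rate, num_epochs, checkpoint, ...
    xmp.spawn(_main_xla, args=(args,), nprocs=8, start_method="fork")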