def instantiate_network(self, args):
    # Labels are stored as a JSON list of characters; join them into a single string.
    with open(args.labels_path) as label_file:
        labels = str(''.join(json.load(label_file)))

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    self.net = DeepSpeech(rnn_hidden_size=args.hidden_size,
                          nb_layers=args.hidden_layers,
                          labels=labels,
                          rnn_type=supported_rnns[rnn_type],
                          audio_conf=audio_conf,
                          bidirectional=args.bidirectional)
    self.init_optimizer(args)

    # Attach decoding/config metadata to the network and move it to the target device.
    self.net.decoder = GreedyDecoder(labels)
    self.net.audio_conf = audio_conf
    self.net.labels = labels
    self.net.device = self.device
    self.net = self.net.to(self.device)
    return self.net
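
# instantiate_network assumes a module-level `supported_rnns` mapping from the
# lowercase --rnn-type flag to the corresponding torch.nn recurrent module.
# A minimal sketch of that mapping (its exact contents are an assumption here):
import torch.nn as nn

supported_rnns = {
    'lstm': nn.LSTM,
    'rnn': nn.RNN,
    'gru': nn.GRU,
}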
def main(args):
    alphabet = alphabet_factory()
    device = torch.device('cpu')

    # Load the trained checkpoint on CPU and rebuild the model around it.
    checkpoint = torch.load('model_best.pth', map_location=device)
    in_features = args.n_mfcc * (2 * args.n_context + 1)
    model = build_deepspeech(in_features=in_features, num_classes=len(alphabet))
    model.load_state_dict(checkpoint['state_dict'])
    print_size_of_model(model)

    decoder = GreedyDecoder()

    if args.quantize:
        # Dynamic quantization: RNN and linear weights to int8, activations stay float.
        model = torch.quantization.quantize_dynamic(model, {nn.RNN, nn.Linear}, dtype=torch.qint8)
        logging.info('quantized model')
        print_size_of_model(model)

    transform = prepare_transformations(args)
    dataset = ProcessedDataset(get_dataset(args.datadir, "dev-clean"), transform, alphabet)
    collate_fn = collate_factory(model_length_function)
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             collate_fn=collate_fn,
                                             drop_last=False)
    test_loop_fn(dataloader, model, criterion, device, 1, decoder, alphabet)
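
# print_size_of_model is not defined in this snippet. A minimal sketch of such a
# helper, in the style commonly used in PyTorch quantization examples to compare
# model size before and after quantization (an assumption, not the original code):
import os
import torch

def print_size_of_model(model):
    # Serialize the state dict to a temporary file and report its size on disk.
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove("temp.p")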
def main(index, args):
    alphabet = alphabet_factory()
    train_dataset, test_dataset = split_dataset(args, alphabet)
    collate_fn = collate_factory(model_length_function)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               # pin_memory=True,
                                               shuffle=True,
                                               collate_fn=collate_fn,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=args.num_workers,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              collate_fn=collate_fn,
                                              drop_last=True)

    # Get loss function, optimizer, and model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_features = args.n_mfcc * (2 * args.n_context + 1)
    model = build_deepspeech(in_features=in_features, num_classes=len(alphabet))
    model = model.to(device)
    logging.info("Number of parameters: %s", count_parameters(model))

    optimizer = get_optimizer(args, model.parameters())
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    decoder = GreedyDecoder()

    train_eval_fn(args.num_epochs, train_loader, test_loader, optimizer, model,
                  criterion, device, decoder, alphabet, args.checkpoint, args.log_steps)
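
# count_parameters is referenced above but not shown. A sketch using the usual
# PyTorch idiom (an assumption; the project's helper may differ):
def count_parameters(model):
    # Total number of trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)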
def distribute(self, args):
    labels = self.net.labels
    audio_conf = self.net.audio_conf
    main_proc = self.initialize_process_group(args)

    if main_proc and self._visdom:
        # Add previous scores to visdom graph
        self._visdom_logger.load_previous_values(self._start_epoch, self._package)
    if main_proc and self._tensorboard:
        # Add previous scores to tensorboard logs
        self._tensorboard_logger.load_previous_values(self._start_epoch, self._package)

    self.net = DistributedDataParallel(self.net)
    self.net.decoder = GreedyDecoder(labels)
    self.net.audio_conf = audio_conf
    self.net.labels = labels
    self.net.device = self.device
    self.net = self.net.to(self.device)
    return main_proc
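
# distribute relies on self.initialize_process_group(args) joining the process
# group and reporting whether the current rank is the main process. A hedged
# sketch with torch.distributed; the flag names (dist_backend, dist_url,
# world_size, rank) are assumptions, not taken from the source:
import torch.distributed as dist

def initialize_process_group(self, args):
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    # Only rank 0 does logging/checkpointing.
    return args.rank == 0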
def continue_from(self, args):
    print("Loading checkpoint model %s" % args.continue_from)
    package = torch.load(args.continue_from, map_location=lambda storage, loc: storage)
    self.net = DeepSpeech.load_model_package(package)
    self.init_optimizer(args)

    # Restore optimizer and mixed-precision (amp) state.
    self._optim_state = package['optim_dict']
    self._amp_state = package['amp']
    self._optimizer.load_state_dict(self._optim_state)
    amp.load_state_dict(self._amp_state)

    self._start_epoch = int(package.get('epoch', 1)) - 1  # Index starts at 0 for training
    self._start_iter = package.get('iteration', None)
    if self._start_iter is None:
        self._start_epoch += 1  # We saved the model after the epoch finished, so start at the next epoch.
        self._start_iter = 0
    else:
        self._start_iter += 1
    self._avg_loss = int(package.get('avg_loss', 0))

    # Restore per-epoch loss/CER/WER history.
    loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], \
        package['wer_results']
    for k, (loss, cer, wer) in enumerate(zip(loss_results, cer_results, wer_results)):
        try:
            self._loss_results[k], self._cer_results[k], self._wer_results[k] = loss, cer, wer
        except IndexError:
            break

    if self._start_epoch > 0:
        self._best_cer = min(cer_results[:self._start_epoch])
        self._best_wer = min(wer_results[:self._start_epoch])
        self._cer = cer_results[self._start_epoch - 1]
        self._wer = wer_results[self._start_epoch - 1]
    else:
        self._best_cer = None
        self._best_wer = None
        self._cer = None
        self._wer = None

    self._package = package
    self.net.decoder = GreedyDecoder(self.net.labels)
    self.net.device = self.device
    return self.net
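
# For reference, continue_from expects the checkpoint package to carry at least
# the keys read above. A hypothetical helper illustrating that layout (the
# function name and its arguments are assumptions; DeepSpeech.load_model_package
# additionally needs whatever model weights/metadata the project serializes):
import torch

def save_training_package(optimizer, amp, epoch, iteration, avg_loss,
                          loss_results, cer_results, wer_results, path):
    package = {
        'optim_dict': optimizer.state_dict(),
        'amp': amp.state_dict(),
        'epoch': epoch + 1,        # stored 1-based; the loader subtracts 1
        'iteration': iteration,    # None when saved at the end of an epoch
        'avg_loss': avg_loss,
        'loss_results': loss_results,
        'cer_results': cer_results,
        'wer_results': wer_results,
    }
    torch.save(package, path)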
def _main_xla(index, args):
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

    alphabet = alphabet_factory()
    train_dataset, test_dataset = split_dataset(args, alphabet)
    collate_fn = collate_factory(model_length_function)

    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               collate_fn=collate_fn,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              collate_fn=collate_fn,
                                              drop_last=True)

    # Scale learning rate to world size
    lr = args.learning_rate * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    in_features = args.n_mfcc * (2 * args.n_context + 1)  # MFCCs per frame times the context window, as in the other entry points
    model = build_deepspeech(in_features=in_features, num_classes=len(alphabet))
    model = model.to(device)
    optimizer = get_optimizer(args, model.parameters())
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    decoder = GreedyDecoder()

    # Wrap the loaders so batches are moved onto the XLA device as they are consumed.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    class XLAProxyOptimizer:
        """XLA proxy optimizer for compatibility with torch.optim.Optimizer.

        xm.optimizer_step triggers the XLA graph execution and synchronizes
        gradients across replicas before applying the update.
        """

        def __init__(self, optimizer):
            self.optimizer = optimizer

        def zero_grad(self):
            self.optimizer.zero_grad()

        def step(self):
            xm.optimizer_step(self.optimizer)

    optimizer = XLAProxyOptimizer(optimizer)
    train_eval_fn(args.num_epochs, train_device_loader, test_device_loader, optimizer,
                  model, criterion, device, decoder, alphabet, args.checkpoint)
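
# _main_xla(index, args) has the (index, args) signature expected by an XLA
# multiprocessing worker. A hedged launch sketch; parse_args and the num_cores
# flag are assumptions, not taken from the source:
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    args = parse_args()  # assumed argument parser providing the flags used above
    # Spawn one process per TPU core; each worker receives (index, args).
    xmp.spawn(_main_xla, args=(args,), nprocs=args.num_cores)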