def init_train_set(epoch, from_iter):
    """Rebuild the module-level ``train_sampler`` and ``train_loader``.

    The training dataset's curriculum is advanced to ``epoch`` (without
    sampling), a fresh bucketing sampler and data loader are created, and
    the sampler is shuffled when the shuffle settings ask for it.

    Args:
        epoch: Epoch index passed to the curriculum and to the shuffle.
        from_iter: Number of leading bins dropped from the sampler. Only
            applied in the non-distributed case.
    """
    global train_loader, train_sampler

    # The ``sample=True`` curriculum variant is intentionally not used.
    train_dataset.set_curriculum_epoch(epoch, sample=False)

    if args.distributed:
        train_sampler = DistributedBucketingSampler(
            train_dataset,
            batch_size=args.batch_size,
            num_replicas=args.world_size,
            rank=args.rank,
        )
    else:
        train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
        # Skip the bins before ``from_iter``.
        train_sampler.bins = train_sampler.bins[from_iter:]

    train_loader = AudioDataLoader(
        train_dataset,
        num_workers=args.num_workers,
        batch_sampler=train_sampler,
    )

    # Keep the sampler's current order when shuffling is disabled or this is
    # epoch 0, unless ``no_sorta_grad`` forces a shuffle.
    keep_order = (args.no_shuffle or epoch == 0) and not args.no_sorta_grad
    if keep_order:
        return

    print("Shuffling batches for the following epochs")
    train_sampler.shuffle(epoch)
else: train_sampler = DistributedBucketingSampler( train_dataset, batch_size=args.batch_size, num_replicas=args.world_size, rank=args.rank) train_loader = AudioDataLoader(train_dataset, num_workers=args.num_workers, batch_sampler=train_sampler) test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad: print("Shuffling batches for the following epochs") train_sampler.shuffle(start_epoch) try: model.load_state_dict(torch.load(args.weights)['state_dict'], strict=False) print('using weights') except: print('not using weighs') model = model.to(device) parameters = model.parameters() optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, nesterov=True, weight_decay=1e-5) if optim_state is not None:
test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels, normalize=True, augment=False) train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size) train_loader = AudioDataLoader(train_dataset, num_workers=args.num_workers, batch_sampler=train_sampler) test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) if not args.no_shuffle and start_epoch != 0: print("Shuffling batches for the following epochs") train_sampler.shuffle() if args.cuda: model = torch.nn.DataParallel(model).cuda() print(model) print("Number of parameters: %d" % DeepSpeech.get_param_size(model)) batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() for epoch in range(start_epoch, args.epochs): model.train() end = time.time() for i, (data) in enumerate(train_loader, start=start_iter):
args = parser.parse_args() if __name__ == '__main__': torch.set_grad_enabled(False) model, _ = load_model(args.model_path) device = torch.device("cuda" if args.cuda else "cpu") label_decoder = LabelDecoder(model.labels) model.eval() model = model.to(device) test_dataset = SpectrogramDataset(audio_conf=model.audio_conf, manifest_filepath=args.test_manifest, labels=model.labels) test_sampler = BucketingSampler(test_dataset, batch_size=args.batch_size) test_loader = AudioDataLoader(test_dataset, batch_sampler=test_sampler, num_workers=args.num_workers) test_sampler.shuffle(1) total_wer, total_cer, total_ler, num_words, num_chars, num_labels = 0, 0, 0, 0, 0, 0 output_data = [] for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader), ascii=True): inputs, targets, input_sizes, target_sizes, filenames = data inputs = inputs.to(device) input_sizes = input_sizes.to(device) outputs = model.transcribe(inputs, input_sizes) for i, target in enumerate(targets): reference = label_decoder.decode(target[:target_sizes[i]].tolist()) transcript = label_decoder.decode(outputs[i]) wer, trans_words, ref_words = calculate_wer(transcript, reference, '\t') cer, trans_chars, ref_chars = calculate_cer(transcript, reference, '\t')
def _checkpoint_state(opt, asr_model, fbank_model, epoch, iters,
                      best_loss, best_acc, acc_report, loss_report):
    """Collect everything ``main`` stores in a checkpoint into one dict."""
    return {
        'asr_state_dict': asr_model.state_dict(),
        # Resume loads 'fbank_state_dict', so it has to be saved as well.
        'fbank_state_dict': fbank_model.state_dict(),
        'opt': opt,
        'epoch': epoch,
        'iters': iters,
        'eps': opt.eps,
        'lr': opt.lr,
        'best_loss': best_loss,
        'best_acc': best_acc,
        'acc_report': acc_report,
        'loss_report': loss_report,
    }


def main():
    """Train the E2E ASR model jointly with the fbank front-end.

    Parses the training options, builds the train/dev datasets and loaders,
    optionally resumes from ``opt.resume``, loads or computes the fbank CMVN
    statistics, and then runs the training loop. Every ``opt.print_freq``
    iterations the current errors are printed and a ``latest`` checkpoint is
    written; every ``opt.validate_freq`` iterations the dev set is evaluated,
    the plots are refreshed and a checkpoint is saved (named as a "best"
    model when the selected criterion improved).

    Raises:
        ValueError: If ``opt.opt_type`` is neither ``'adadelta'`` nor
            ``'adam'``.
        RuntimeError: If no CMVN statistics could be computed from the
            training data.
    """
    opt = TrainOptions().parse()

    visualizer = Visualizer(opt)
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'], 'acc.png')
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss'], 'loss.png')

    # data
    logging.info("Building dataset.")
    # The train/dev directories and the dictionary directory are the inputs.
    train_dataset = SequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'train'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )
    val_dataset = SequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'dev'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = SequentialDataLoader(train_dataset,
                                        num_workers=opt.num_workers,
                                        batch_sampler=train_sampler)
    val_loader = SequentialDataLoader(val_dataset,
                                      batch_size=int(opt.batch_size / 2),
                                      num_workers=opt.num_workers,
                                      shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup a model
    asr_model = E2E(opt)
    fbank_model = FbankModel(opt)
    lr = opt.lr  # default=0.005
    eps = opt.eps  # default=1e-8
    iters = opt.iters  # default=0
    start_epoch = opt.start_epoch  # default=0
    best_loss = opt.best_loss  # default=float('inf')
    best_acc = opt.best_acc  # default=0

    # Resume from a checkpoint when one is given.
    if opt.resume:
        model_path = os.path.join(opt.works_dir, opt.resume)
        if os.path.isfile(model_path):
            package = torch.load(model_path, map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            best_acc = package.get('best_acc', 0)
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))

            acc_report = package.get('acc_report', acc_report)
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(acc_report, 'acc.png')
            visualizer.set_plot_report(loss_report, 'loss.png')

            asr_model = E2E.load_model(model_path, 'asr_state_dict')
            fbank_model = FbankModel.load_model(model_path, 'fbank_state_dict')
            logging.info('Loading model {} and iters {}'.format(model_path, iters))
        else:
            print("no checkpoint found at {}".format(model_path))

    # convert to cuda
    asr_model.cuda()
    fbank_model.cuda()
    print(asr_model)
    print(fbank_model)

    # Setup an optimizer over the trainable parameters of both models.
    parameters = filter(
        lambda p: p.requires_grad,
        itertools.chain(asr_model.parameters(), fbank_model.parameters()))
    if opt.opt_type == 'adadelta':
        optimizer = torch.optim.Adadelta(parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=lr, betas=(opt.beta1, 0.999))
    else:
        # Fail here instead of with a NameError at the first optimizer use.
        raise ValueError('unsupported opt_type: {}'.format(opt.opt_type))

    asr_model.train()
    fbank_model.train()
    # NOTE(review): the ramp-up is configured with the *epoch* options but is
    # updated with the iteration counter below -- confirm this is intended.
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_epoch,
                                           opt.sche_samp_final_epoch,
                                           opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    # Load the fbank CMVN statistics, or compute and cache them.
    fbank_cmvn_file = os.path.join(opt.exp_path, 'fbank_cmvn.npy')
    if os.path.exists(fbank_cmvn_file):
        fbank_cmvn = np.load(fbank_cmvn_file)
    else:
        fbank_cmvn = None
        for data in train_loader:
            utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
            # Original stop condition: compute_cmvn returned the statistics.
            if fbank_cmvn is not None:
                np.save(fbank_cmvn_file, fbank_cmvn)
                print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
                break
            # Relying on None alone is fragile, so also stop once enough
            # frames were processed and run compute_cmvn one last time.
            if fbank_model.cmvn_processed_num >= fbank_model.cmvn_num:
                fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
                np.save(fbank_cmvn_file, fbank_cmvn)
                print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
                break
        print(fbank_model.cmvn_processed_num)
        if fbank_cmvn is None:
            raise RuntimeError(
                'could not compute fbank CMVN statistics from the training data')
    fbank_cmvn = torch.FloatTensor(fbank_cmvn)

    # Start training.
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        # NOTE: the original enumerated with a resume-derived ``start`` value;
        # that only offset an unused index and never skipped batches.
        for data in train_loader:
            utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            fbank_features = fbank_model(inputs, fbank_cmvn)
            # asr_model is unpacked into three values here: CTC loss,
            # attention loss and accuracy.
            loss_ctc, loss_att, acc = asr_model(fbank_features, targets,
                                                input_sizes, target_sizes,
                                                sche_samp_rate)
            loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()  # compute backwards
            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(),
                                                       opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                optimizer.step()

            iters += 1
            errors = {
                'train/loss': loss.item(),
                'train/loss_ctc': loss_ctc.item(),
                'train/acc': acc,
                'train/loss_att': loss_att.item(),
            }
            visualizer.set_current_errors(errors)

            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = _checkpoint_state(opt, asr_model, fbank_model, epoch,
                                          iters, best_loss, best_acc,
                                          acc_report, loss_report)
                filename = 'latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)

            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(iters, sche_samp_rate))
                asr_model.eval()
                fbank_model.eval()
                torch.set_grad_enabled(False)
                num_saved_attention = 0
                for _, data in tqdm(enumerate(val_loader, start=0)):
                    utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
                    fbank_features = fbank_model(inputs, fbank_cmvn)
                    # Scheduled-sampling rate is fixed to 0.0 for validation.
                    loss_ctc, loss_att, acc = asr_model(
                        fbank_features, targets, input_sizes, target_sizes, 0.0)
                    loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
                    errors = {
                        'val/loss': loss.item(),
                        'val/loss_ctc': loss_ctc.item(),
                        'val/acc': acc,
                        'val/loss_att': loss_att.item(),
                    }
                    visualizer.set_current_errors(errors)

                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(
                                fbank_features, targets, input_sizes, target_sizes)
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(
                                    utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x])
                                visualizer.plot_attention(
                                    att_w, dec_len, enc_len, file_name)
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:
                                    break
                asr_model.train()
                fbank_model.train()
                torch.set_grad_enabled(True)

                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')
                filename = None
                # Compare strings by value; ``is not`` tested object identity.
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        logging.info('val_acc {} < best_acc {}'.format(
                            val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(optimizer, opt.eps_decay)
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(
                            val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(optimizer, opt.eps_decay)
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))
                state = _checkpoint_state(opt, asr_model, fbank_model, epoch,
                                          iters, best_loss, best_acc,
                                          acc_report, loss_report)
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                visualizer.reset()
def train_main(args):
    """Train a DeepSpeech model with CTC loss and validate after every epoch.

    Either resumes from ``args.continue_from`` or builds a new model from the
    label file and audio settings in ``args``. Each epoch runs one pass over
    the training loader, evaluates WER/CER on the validation manifest with a
    greedy decoder, optionally logs to Visdom / TensorBoard, writes
    checkpoints, anneals the learning rate, and saves the model to
    ``args.model_path`` whenever the validation WER improves.

    Args:
        args: Parsed command-line namespace. The attributes read here include
            the distributed settings (``world_size``, ``rank``, ``gpu_rank``,
            ``dist_backend``, ``dist_url``), the data/model options and the
            logging/checkpoint switches.
    """
    args.distributed = args.world_size > 1
    main_proc = True
    if args.distributed:
        if args.gpu_rank:
            torch.cuda.set_device(int(args.gpu_rank))
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        main_proc = args.rank == 0  # Only the first proc should save models
    save_folder = args.save_folder

    loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
        args.epochs)
    best_wer = None
    if args.visdom and main_proc:
        from visdom import Visdom

        viz = Visdom()
        opts = dict(title=args.id, ylabel='', xlabel='Epoch', legend=['Loss', 'WER', 'CER'])
        viz_window = None
        epochs = torch.arange(1, args.epochs + 1)
    if args.tensorboard and main_proc:
        os.makedirs(args.log_dir, exist_ok=True)
        from tensorboardX import SummaryWriter

        tensorboard_writer = SummaryWriter(args.log_dir)
    os.makedirs(save_folder, exist_ok=True)

    avg_loss, start_epoch, start_iter = 0, 0, 0
    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from, map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        labels = DeepSpeech.get_labels(model)
        audio_conf = DeepSpeech.get_audio_conf(model)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters, lr=args.lr,
                                    momentum=args.momentum, nesterov=True)
        if not args.finetune:  # Don't want to restart training
            if args.cuda:
                model.cuda()
            optimizer.load_state_dict(package['optim_dict'])
            start_epoch = int(package.get('epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
                start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
                start_iter = 0
            else:
                start_iter += 1
            # avg_loss is a running sum of float losses; int() would truncate
            # the value restored from the checkpoint.
            avg_loss = float(package.get('avg_loss', 0))
            loss_results, cer_results, wer_results = package['loss_results'], package[
                'cer_results'], package['wer_results']
            if main_proc and args.visdom and \
                    package[
                        'loss_results'] is not None and start_epoch > 0:  # Add previous scores to visdom graph
                x_axis = epochs[0:start_epoch]
                y_axis = torch.stack(
                    (loss_results[0:start_epoch], wer_results[0:start_epoch], cer_results[0:start_epoch]),
                    dim=1)
                viz_window = viz.line(
                    X=x_axis,
                    Y=y_axis,
                    opts=opts,
                )
            if main_proc and args.tensorboard and \
                    package[
                        'loss_results'] is not None and start_epoch > 0:  # Previous scores to tensorboard logs
                for i in range(start_epoch):
                    values = {
                        'Avg Train Loss': loss_results[i],
                        'Avg WER': wer_results[i],
                        'Avg CER': cer_results[i]
                    }
                    tensorboard_writer.add_scalars(args.id, values, i + 1)
    else:
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))

        rnn_type = args.rnn_type.lower()
        assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
        model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                           nb_layers=args.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[rnn_type],
                           audio_conf=audio_conf,
                           bidirectional=args.bidirectional)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters, lr=args.lr,
                                    momentum=args.momentum, nesterov=True)
    criterion = CTCLoss()
    decoder = GreedyDecoder(labels)
    train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels,
                                       normalize=True, augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels,
                                      normalize=True, augment=False)
    if not args.distributed:
        train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
    else:
        train_sampler = DistributedBucketingSampler(train_dataset, batch_size=args.batch_size,
                                                    num_replicas=args.world_size, rank=args.rank)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers, batch_sampler=train_sampler)
    test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad:
        print("Shuffling batches for the following epochs")
        train_sampler.shuffle(start_epoch)

    if args.cuda:
        model.cuda()
        if args.distributed:
            # Key the device list on gpu_rank, like the set_device call at the
            # top; keying it on args.rank left rank 0 without an explicit
            # device even when gpu_rank was given.
            device_ids = (int(args.gpu_rank),) if args.gpu_rank else None
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        end = time.time()
        start_epoch_time = time.time()
        for i, data in enumerate(train_loader, start=start_iter):
            if i == len(train_sampler):
                break
            inputs, targets, input_percentages, target_sizes = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            # measure data loading time
            data_time.update(time.time() - end)

            if args.cuda:
                inputs = inputs.cuda()

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            loss = criterion(out, targets, output_sizes, target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            inf = float("inf")
            if args.distributed:
                loss_value = reduce_tensor(loss, args.world_size)[0]
            else:
                loss_value = loss.item()
            if loss_value == inf or loss_value == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0

            avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # compute gradient
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
            # SGD step
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.silent:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                    (epoch + 1), (i + 1), len(train_sampler), batch_time=batch_time, data_time=data_time, loss=losses))
            if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0 and main_proc:
                file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth' % (save_folder, epoch + 1, i + 1)
                print("Saving checkpoint model to %s" % file_path)
                torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, iteration=i,
                                                loss_results=loss_results,
                                                wer_results=wer_results, cer_results=cer_results, avg_loss=avg_loss),
                           file_path)
            # Drop the references so the tensors can be freed before the next batch.
            del loss
            del out
        avg_loss /= len(train_sampler)

        epoch_time = time.time() - start_epoch_time
        print('Training Summary Epoch: [{0}]\t'
              'Time taken (s): {epoch_time:.0f}\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, epoch_time=epoch_time, loss=avg_loss))

        start_iter = 0  # Reset start iteration for next epoch
        total_cer, total_wer = 0, 0
        model.eval()
        with torch.no_grad():
            for _, data in tqdm(enumerate(test_loader), total=len(test_loader)):
                inputs, targets, input_percentages, target_sizes = data
                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

                # unflatten targets
                split_targets = []
                offset = 0
                for size in target_sizes:
                    split_targets.append(targets[offset:offset + size])
                    offset += size

                if args.cuda:
                    inputs = inputs.cuda()

                out, output_sizes = model(inputs, input_sizes)

                decoded_output, _ = decoder.decode(out.data, output_sizes)
                target_strings = decoder.convert_to_strings(split_targets)
                wer, cer = 0, 0
                for x in range(len(target_strings)):
                    transcript, reference = decoded_output[x][0], target_strings[x][0]
                    # Error counts are normalised by the reference length.
                    wer += decoder.wer(transcript, reference) / float(len(reference.split()))
                    cer += decoder.cer(transcript, reference) / float(len(reference))
                total_cer += cer
                total_wer += wer
                del out
            wer = total_wer / len(test_loader.dataset)
            cer = total_cer / len(test_loader.dataset)
            wer *= 100
            cer *= 100
            loss_results[epoch] = avg_loss
            wer_results[epoch] = wer
            cer_results[epoch] = cer
            print('Validation Summary Epoch: [{0}]\t'
                  'Average WER {wer:.3f}\t'
                  'Average CER {cer:.3f}\t'.format(
                epoch + 1, wer=wer, cer=cer))

        if args.visdom and main_proc:
            x_axis = epochs[0:epoch + 1]
            y_axis = torch.stack(
                (loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1)
            if viz_window is None:
                viz_window = viz.line(
                    X=x_axis,
                    Y=y_axis,
                    opts=opts,
                )
            else:
                viz.line(
                    X=x_axis.unsqueeze(0).expand(y_axis.size(1), x_axis.size(0)).transpose(0, 1),  # Visdom fix
                    Y=y_axis,
                    win=viz_window,
                    update='replace',
                )
        if args.tensorboard and main_proc:
            values = {
                'Avg Train Loss': avg_loss,
                'Avg WER': wer,
                'Avg CER': cer
            }
            tensorboard_writer.add_scalars(args.id, values, epoch + 1)
            if args.log_params:
                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    tensorboard_writer.add_histogram(tag, to_np(value), epoch + 1)
                    tensorboard_writer.add_histogram(tag + '/grad', to_np(value.grad), epoch + 1)
        if args.checkpoint and main_proc:
            file_path = '%s/deepspeech_%d.pth' % (save_folder, epoch + 1)
            torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
                                            wer_results=wer_results, cer_results=cer_results),
                       file_path)
        # anneal lr
        optim_state = optimizer.state_dict()
        optim_state['param_groups'][0]['lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
        optimizer.load_state_dict(optim_state)
        print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))

        if (best_wer is None or best_wer > wer) and main_proc:
            print("Found better validated model, saving to %s" % args.model_path)
            torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
                                            wer_results=wer_results, cer_results=cer_results), args.model_path)
            best_wer = wer

        avg_loss = 0
        if not args.no_shuffle:
            print("Shuffling batches...")
            train_sampler.shuffle(epoch)
normalize=True, augment=args.augment) test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels, normalize=True, augment=False) if not args.distributed: train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size) else: train_sampler = DistributedBucketingSampler(train_dataset, batch_size=args.batch_size, num_replicas=args.world_size, rank=args.rank) train_loader = AudioDataLoader(train_dataset, num_workers=args.num_workers, batch_sampler=train_sampler) test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad: print("Shuffling batches for the following epochs") train_sampler.shuffle(start_epoch) if args.cuda: model.cuda() if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=(int(args.gpu_rank),) if args.rank else None) print(model) print("Number of parameters: %d" % DeepSpeech.get_param_size(model)) batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() for epoch in range(start_epoch, args.epochs):