Example #1
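# Assumed imports for this snippet (module paths are an assumption based on a
# deepspeech.pytorch-style project layout; the original file header is not shown):
#   import json
#   import torch
#   from torch.nn.parallel import DistributedDataParallel
#   from warpctc_pytorch import CTCLoss
#   from apex.fp16_utils import FP16_Optimizer
#   from model import DeepSpeech, supported_rnns, convert_model_to_half
#   from data.data_loader import AudioDataLoader

# Load the label set (the characters the model can emit) from a JSON file.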
with open(args.labels_path) as label_file:
    labels = ''.join(json.load(label_file))

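# Audio front-end configuration. Typical values (an assumption, not fixed by
# this snippet) are sample_rate=16000 and window_size=0.02 (20 ms frames).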
audio_conf = dict(sample_rate=args.sample_rate,
                  window_size=args.window_size)

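# Build the DeepSpeech acoustic model from the command-line arguments.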
model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                   nb_layers=args.hidden_layers,
                   audio_conf=audio_conf,
                   labels=labels,
                   rnn_type=supported_rnns[args.rnn_type],
                   mixed_precision=args.mixed_precision)
model = model.to(device)
if args.mixed_precision:
    model = convert_model_to_half(model)
print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

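# Optimize all model parameters with SGD + Nesterov momentum; note that the
# learning rate, momentum, and weight decay are hard-coded here rather than
# read from args.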
parameters = model.parameters()
optimizer = torch.optim.SGD(parameters, lr=3e-4, momentum=0.9, nesterov=True, weight_decay=1e-5)
if args.distributed:
    model = DistributedDataParallel(model)
if args.mixed_precision:
    optimizer = FP16_Optimizer(optimizer,
                               static_loss_scale=args.static_loss_scale,
                               dynamic_loss_scale=args.dynamic_loss_scale)
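# With apex's (now deprecated) FP16_Optimizer wrapper, the backward pass in
# the training loop goes through the optimizer rather than the loss; a
# minimal sketch of the usual pattern (not shown in this snippet):
#   optimizer.backward(loss)   # replaces loss.backward()
#   optimizer.step()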

criterion = CTCLoss()
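# CTC loss over unaligned transcripts. In the training loop it is typically
# invoked as (a sketch, assuming the warp-ctc binding's signature):
#   loss = criterion(out, targets, output_sizes, target_sizes)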

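# Batching setup: batch samplers drive the clean and adversarial training
# loaders, while the test loader uses a plain batch size.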
seconds = int(args.seconds)
batch_size = int(args.batch_size)
train_loader_clean = AudioDataLoader(train_dataset_clean,
                                     num_workers=args.num_workers,
                                     batch_sampler=train_sampler_clean)
train_loader_adv = AudioDataLoader(train_dataset_adv,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler_adv)
test_loader = AudioDataLoader(test_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers)
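# SortaGrad-style batch shuffling after the first epoch is left disabled: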
'''
if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad:
    print("Shuffling batches for the following epochs")
    train_sampler.shuffle(start_epoch)
'''
model = model.to(device)
denoiser = denoiser.to(device)
if args.mixed_precision:
    model = convert_model_to_half(model)
    denoiser = convert_model_to_half(denoiser)
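# Only the denoiser's parameters are optimized; the acoustic model stays fixed.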
parameters = denoiser.parameters()
optimizer = torch.optim.Adam(parameters, lr=args.lr)
# Disabled SGD-style arguments: momentum=args.momentum, nesterov=True, weight_decay=1e-5
if args.distributed:
    model = DistributedDataParallel(model)
    denoiser = DistributedDataParallel(denoiser)
if args.mixed_precision:
    optimizer = FP16_Optimizer(optimizer,
                               static_loss_scale=args.static_loss_scale,
                               dynamic_loss_scale=args.dynamic_loss_scale)
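# Restore the optimizer state when resuming from a checkpoint.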
if optim_state is not None:
    optimizer.load_state_dict(optim_state)
print(model)
print("Number of parameters: %d" % DeepSpeech.get_param_size(model))