Exemplo n.º 1
0
                      "{:.3f}".format(average_latent_loss0), "latent1=",
                      "{:.3f}".format(average_latent_loss1))

        if (epoch % save_epoch == 0) or (epoch == training_epochs - 1):
            torch.save(scan.state_dict(),
                       '{}/scan_epoch_{}.pth'.format(exp, epoch))


# Build the dataset wrapper and let it prepare itself before the models are
# created.
data_manager = DataManager()
data_manager.prepare()

# Instantiate the three networks, then restore their weights from disk.
dae, vae, scan = DAE(), VAE(), SCAN()

# NOTE(review): the CUDA path restores SCAN from a fixed checkpoint file while
# the CPU path restores it from `exp + '/' + opt.load` -- confirm which of the
# two is intended; both are kept exactly as they were.
if not use_cuda:
    # CPU run: a map_location callable that returns the storage it is given
    # leaves every tensor on the CPU, so GPU-saved checkpoints still load.
    dae.load_state_dict(
        torch.load('save/dae/dae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    vae.load_state_dict(
        torch.load('save/vae/vae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    scan.load_state_dict(
        torch.load(exp + '/' + opt.load,
                   map_location=lambda storage, loc: storage))
else:
    # GPU run: load with the default device mapping, then move each network
    # onto the GPU.
    dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth'))
    vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth'))
    scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth'))
    dae = dae.cuda()
    vae = vae.cuda()
    scan = scan.cuda()

if opt.train:
Exemplo n.º 2
0
            target = torch.transpose(target, 1, 3)
            if use_cuda: hsv_image_t = target.data.cpu().numpy()
            else: hsv_image_t = target.data.numpy()
            rgb_image_t = utils.convert_hsv_to_rgb(hsv_image_t[0])
            utils.save_image(rgb_image_t,
                             "{}/target_epoch_{}.png".format(exp, epoch))

        # Save to checkpoint
        if (epoch % save_epoch == 0) or (epoch == training_epochs - 1):
            torch.save(dae.state_dict(),
                       '{}/dae_epoch_{}.pth'.format(exp, epoch))


# Script entry: prepare the data, build the DAE, optionally resume from a
# checkpoint named by `opt.load` (relative to the experiment directory `exp`),
# and train when `opt.train` is set.
data_manager = DataManager()
data_manager.prepare()
dae = DAE()
if opt.load != '':
    print('loading {}'.format(opt.load))
    if use_cuda:
        # torch.load needs the checkpoint path; use the same file as the CPU
        # branch and rely on the default device mapping.
        dae.load_state_dict(torch.load(exp + '/' + opt.load))
    else:
        # Returning the storage unchanged from map_location keeps every
        # tensor on the CPU, so checkpoints saved on a GPU still load here.
        dae.load_state_dict(
            torch.load(exp + '/' + opt.load,
                       map_location=lambda storage, loc: storage))

# Move the model to the GPU only after any checkpoint has been restored.
if use_cuda: dae = dae.cuda()

if opt.train:
    dae_optimizer = optim.Adam(dae.parameters(), lr=1e-4, eps=1e-8)
    train_dae(dae, data_manager, dae_optimizer)
Exemplo n.º 3
0
# Optional destination for the raw acoustic output.
parser.add_argument('--output-path', type=str, default=None,
                    help="Where to save raw acoustic output")

# Let the decoder helper register its own options on the same parser; it
# returns the parser to keep using.
parser = add_decoder_args(parser)

# Boolean switch: present on the command line -> True, absent -> False.
parser.add_argument('--save-output', action="store_true",
                    help="Saves output of model from test")

# Parse the command line at import time so `args` is available module-wide.
args = parser.parse_args()

if __name__ == '__main__':
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if args.cuda else "cpu")
    model = load_model(device, args.model_path, args.cuda)
    denoiser = DAE()
    denoiser.load_state_dict(
        torch.load('./models/denoiser_deepspeech_final.pth'))
    denoiser = denoiser.to(device)
    denoiser.eval()

    if args.decoder == "beam":
        from decoder import BeamCTCDecoder

        decoder = BeamCTCDecoder(model.labels,
                                 lm_path=args.lm_path,
                                 alpha=args.alpha,
                                 beta=args.beta,
                                 cutoff_top_n=args.cutoff_top_n,
                                 cutoff_prob=args.cutoff_prob,
                                 beam_width=args.beam_width,
                                 num_processes=args.lm_workers)
    elif args.decoder == "greedy":