Example #1
    # Build the network; "Network" below is a placeholder for the actual
    # model class, whose call is truncated in this fragment.
    net = Network(att_size=20, cla_size=350, cla_layers=5,
                  num_classes=59,
                  tra_type='identity', rnn_mode='LSTM', cla_dropout=0.3)

    # Load network
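    # map_location=lambda storage, loc: storage forces all tensors onto the
    # CPU at load time, so a GPU-trained checkpoint also loads on CPU-only
    # machines; tensors are moved to the GPU afterwards only if cuda is set.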
    net.load_state_dict(torch.load(net_file, map_location=lambda storage, loc: storage))
    net.eval()
    if cuda:
        net.cuda()
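
    # The forward passes in infer() below are inference-only; if infer() does
    # not already do so (assumption), wrapping them in torch.no_grad() skips
    # autograd bookkeeping and reduces memory use:
    #   with torch.no_grad():
    #       inference_list = infer(model=net, dataloader=test_loader, cuda=cuda)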

    # Loop over subsets
    for subset in all_subsets:

        # Create dataset, sampler, loader
        composed = transforms.Compose([warp_ctc_shift(), standardization('sample')])
        testset = dset(h5file=h5_file, dataset_mode=dataset_mode, subset=subset, transform=composed)
        sampler = HighThroughputSampler(testset, shuffle_batches=False, num_splits=1, max_frames=max_frame_cache,
                                        debug=0, roll=False)
        batch_sampler = SimpleBatchSampler(sampler=sampler)
        test_loader = torch.utils.data.DataLoader(testset, batch_sampler=batch_sampler, num_workers=0,
                                                  collate_fn=collate_fn)
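
        # Optional sanity check (hypothetical sketch): pull one collated batch
        # to inspect shapes before running full inference.
        #   first_batch = next(iter(test_loader))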

        # Get network output
        inference_list = infer(model=net, dataloader=test_loader, cuda=cuda)
        PER, WER, CER = error_preliminary(inference_list, subset)
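
        # Export outputs for offline Kaldi decoding, once per scaling factor;
        # pos_scale is passed to kaldi_converter as `scale` and presumably
        # scales the written log-probabilities.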
        for pos_scale in all_pos_scales:

            kw = kaldi_converter(inference_list, filename='{}_{}'.format(subset, pos_scale), scale=pos_scale)
            logprob = kw.write_logprob()
            reference = kw.write_reference()

Example #2

    log_dict = collections.OrderedDict()

    # Prepare transforms
    train_transforms = [tl.warp_ctc_shift(), tl.standardization(args.standardization), tl.gaussian_noise(args.noise)]
    val_transforms = [tl.warp_ctc_shift(), tl.standardization(args.standardization)]

    if args.concatenation:
        train_transforms.append(tl.concatenation())
        val_transforms.append(tl.concatenation())
    train_composed = transforms.Compose(train_transforms)
    val_composed = transforms.Compose(val_transforms)
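
    # transforms.Compose applies the list in order: standardization operates
    # on warp-shifted features, Gaussian noise is injected after
    # standardization, and concatenation (if enabled) runs last.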

    # Create datasets
    trainset = dset(h5file=args.dataset, dataset_mode=args.dataset_mode, subset=args.trainset, speaker_wise=False,
                    transform=train_composed)
    valset = dset(h5file=args.dataset, dataset_mode=args.dataset_mode, subset=args.valset, speaker_wise=False,
                  transform=val_composed)

    # Define sampler
    train_sampler = HighThroughputSampler(trainset, shuffle_batches=True, num_splits=3, max_frames=args.max_frame_cache,
                                          debug=0)
    train_batch_sampler = SimpleBatchSampler(sampler=train_sampler)
    val_sampler = HighThroughputSampler(valset, shuffle_batches=False, roll=True, max_frames=args.max_frame_cache,
                                        debug=0)
    val_batch_sampler = SimpleBatchSampler(sampler=val_sampler)
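
    # The samplers produce ready-made index batches (max_frames presumably
    # caps the total frames per batch), so the loaders receive them via
    # batch_sampler and leave batch_size unset.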
    train_loader = torch.utils.data.DataLoader(trainset, batch_sampler=train_batch_sampler, num_workers=1,
                                               collate_fn=collate_fn)
    val_loader = torch.utils.data.DataLoader(valset, batch_sampler=val_batch_sampler, num_workers=1,
                                             collate_fn=collate_fn)
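
    # A typical training loop from here would look like this sketch
    # (criterion, optimizer and args.epochs are hypothetical, not from the
    # original fragment):
    #   for epoch in range(args.epochs):
    #       for batch in train_loader:
    #           loss = criterion(batch)        # hypothetical loss computation
    #           optimizer.zero_grad()
    #           loss.backward()
    #           optimizer.step()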