コード例 #1
0
ファイル: train_olr.py プロジェクト: joaomonteirof/e2e_LID
    device = None

if args.model == 'mfcc':
    model = model_.cnn_lstm_mfcc(n_z=args.latent_size,
                                 proj_size=len(train_dataset.speakers_list)
                                 if args.softmax != 'none' else 0,
                                 ncoef=args.ncoef,
                                 sm_type=args.softmax)
elif args.model == 'fb':
    model = model_.cnn_lstm_fb(n_z=args.latent_size,
                               proj_size=len(train_dataset.speakers_list)
                               if args.softmax != 'none' else 0,
                               sm_type=args.softmax)
elif args.model == 'resnet_fb':
    model = model_.ResNet_fb(n_z=args.latent_size,
                             proj_size=len(train_dataset.speakers_list)
                             if args.softmax != 'none' else 0,
                             sm_type=args.softmax)
elif args.model == 'resnet_mfcc':
    model = model_.ResNet_mfcc(n_z=args.latent_size,
                               proj_size=len(train_dataset.speakers_list)
                               if args.softmax != 'none' else 0,
                               ncoef=args.ncoef,
                               sm_type=args.softmax)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               proj_size=len(train_dataset.speakers_list)
                               if args.softmax != 'none' else 0,
                               ncoef=args.ncoef,
                               sm_type=args.softmax)
elif args.model == 'resnet_stats':
    model = model_.ResNet_stats(n_z=args.latent_size,
コード例 #2
0
            'There is no checkpoint/model path. Use arg --cp-path to indicate the path!'
        )

    print('Cuda Mode is: {}'.format(args.cuda))

    if args.cuda:
        set_device()

    if args.model == 'mfcc':
        model = model_.cnn_lstm_mfcc(n_z=args.latent_size,
                                     proj_size=None,
                                     ncoef=args.ncoef)
    elif args.model == 'fb':
        model = model_.cnn_lstm_fb(n_z=args.latent_size, proj_size=None)
    elif args.model == 'resnet_fb':
        model = model_.ResNet_fb(n_z=args.latent_size, proj_size=None)
    elif args.model == 'resnet_mfcc':
        model = model_.ResNet_mfcc(n_z=args.latent_size,
                                   proj_size=None,
                                   ncoef=args.ncoef)
    elif args.model == 'resnet_lstm':
        model = model_.ResNet_lstm(n_z=args.latent_size,
                                   proj_size=None,
                                   ncoef=args.ncoef)
    elif args.model == 'resnet_stats':
        model = model_.ResNet_stats(n_z=args.latent_size,
                                    proj_size=None,
                                    ncoef=args.ncoef)
    elif args.model == 'lcnn9_mfcc':
        model = model_.lcnn_9layers(n_z=args.latent_size,
                                    proj_size=None,
コード例 #3
0
    if args.cuda:
        set_device()

    if args.model == 'mfcc':
        model = model_.cnn_lstm_mfcc(n_z=args.latent_size,
                                     proj_size=len(list(labels_dict.keys())),
                                     ncoef=args.ncoef,
                                     sm_type=args.softmax)
    elif args.model == 'fb':
        model = model_.cnn_lstm_fb(n_z=args.latent_size,
                                   proj_size=len(list(labels_dict.keys())),
                                   sm_type=args.softmax)
    elif args.model == 'resnet_fb':
        model = model_.ResNet_fb(n_z=args.latent_size,
                                 proj_size=len(list(labels_dict.keys())),
                                 sm_type=args.softmax)
    elif args.model == 'resnet_mfcc':
        model = model_.ResNet_mfcc(n_z=args.latent_size,
                                   proj_size=len(list(labels_dict.keys())),
                                   ncoef=args.ncoef,
                                   sm_type=args.softmax)
    elif args.model == 'resnet_lstm':
        model = model_.ResNet_lstm(n_z=args.latent_size,
                                   proj_size=len(list(labels_dict.keys())),
                                   ncoef=args.ncoef,
                                   sm_type=args.softmax)
    elif args.model == 'resnet_stats':
        model = model_.ResNet_stats(n_z=args.latent_size,
                                    proj_size=len(list(labels_dict.keys())),
                                    ncoef=args.ncoef,
コード例 #4
0
ファイル: test_load_arch.py プロジェクト: twistedmove/e2e_LID
                    type=int,
                    default=13,
                    metavar='N',
                    help='number of MFCCs (default: 23)')
# Opt-in flag for the layer-wise norm comparison; stays False unless the
# switch is given on the command line.
parser.add_argument(
    '--pairwise',
    action='store_true',
    default=False,
    help='Enables layer-wise comparison of norms',
)

# All options are registered at this point, so parse the command line.
args = parser.parse_args()

# Instantiate the architecture selected by --model. Every variant receives the
# latent size (args.latent_size); some variants additionally take args.ncoef.
# All branches form a single if/elif chain so exactly one constructor runs.
if args.model == 'mfcc':
    model = model_.cnn_lstm_mfcc(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'fb':
    model = model_.cnn_lstm_fb(n_z=args.latent_size)
elif args.model == 'resnet_fb':
    model = model_.ResNet_fb(n_z=args.latent_size)
elif args.model == 'resnet_mfcc':
    model = model_.ResNet_mfcc(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'resnet_stats':
    model = model_.ResNet_stats(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'lcnn9_mfcc':
    model = model_.lcnn_9layers(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'lcnn29_mfcc':
    model = model_.lcnn_29layers_v2(n_z=args.latent_size, ncoef=args.ncoef)
else:
    # Fail fast with a clear message instead of leaving `model` unbound for
    # the code that follows.
    raise ValueError('Unknown model: {}'.format(args.model))

if args.pairwise:

    if args.model == 'mfcc':
        clone_model = model_.cnn_lstm_mfcc(n_z=args.latent_size,