Code example #1
File: hp_search.py  Project: entn-at/e2e_verification
def train(lr, l2, momentum, patience, latent_size, n_hidden, hidden_size, n_frames, model, ncoef, dropout_prob, epochs, batch_size, n_workers, cuda, train_hdf_file, valid_hdf_file, valid_n_cycles, cp_path, softmax):

	if cuda:
		device=get_freer_gpu()

	train_dataset=Loader_test(hdf5_name=train_hdf_file, max_nb_frames=int(n_frames))
	train_loader=torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=n_workers, worker_init_fn=set_np_randomseed)

	valid_dataset = Loader(hdf5_name = valid_hdf_file, max_nb_frames = int(n_frames), n_cycles=valid_n_cycles)
	valid_loader=torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=n_workers, worker_init_fn=set_np_randomseed)

	if args.model == 'resnet_stats':
		model = model_.ResNet_stats(n_z=int(latent_size), nh=int(n_hidden), n_h=int(hidden_size), proj_size=len(train_dataset.speakers_list), ncoef=ncoef, dropout_prob=dropout_prob, sm_type=softmax)
	elif args.model == 'resnet_mfcc':
		model = model_.ResNet_mfcc(n_z=int(latent_size), nh=int(n_hidden), n_h=int(hidden_size), proj_size=len(train_dataset.speakers_list), ncoef=ncoef, dropout_prob=dropout_prob, sm_type=softmax)
	elif args.model == 'resnet_lstm':
		model = model_.ResNet_lstm(n_z=int(latent_size), nh=int(n_hidden), n_h=int(hidden_size), proj_size=len(train_dataset.speakers_list), ncoef=ncoef, dropout_prob=dropout_prob, sm_type=softmax)
	elif args.model == 'resnet_small':
		model = model_.ResNet_small(n_z=int(latent_size), nh=int(n_hidden), n_h=int(hidden_size), proj_size=len(train_dataset.speakers_list), ncoef=ncoef, dropout_prob=dropout_prob, sm_type=softmax)
	elif args.model == 'resnet_large':
		model = model_.ResNet_large(n_z=int(latent_size), nh=int(n_hidden), n_h=int(hidden_size), proj_size=len(train_dataset.speakers_list), ncoef=ncoef, dropout_prob=dropout_prob, sm_type=softmax)

	if cuda:
		model=model.cuda(device)
	else:
		device=None

	optimizer=optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=l2)

	trainer=TrainLoop(model, optimizer, train_loader, valid_loader, patience=int(patience), verbose=-1, device=device, cp_name=get_file_name(cp_path), save_cp=False, checkpoint_path=cp_path, pretrain=False, cuda=cuda)

	return trainer.train(n_epochs=epochs)
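
The same if/elif dispatch on the model name recurs in nearly every example below. A minimal dict-based sketch of the same selection, assuming the constructors exported by the model_ module in the snippet above and one shared set of keyword arguments:

# Hypothetical dict-based factory equivalent to the if/elif chain above;
# the constructor names come from the snippet, the shared kwargs are assumed.
model_classes = {
	'resnet_stats': model_.ResNet_stats,
	'resnet_mfcc': model_.ResNet_mfcc,
	'resnet_lstm': model_.ResNet_lstm,
	'resnet_small': model_.ResNet_small,
	'resnet_large': model_.ResNet_large,
}
common_kwargs = dict(n_z=int(latent_size), nh=int(n_hidden), n_h=int(hidden_size),
	proj_size=len(train_dataset.speakers_list), ncoef=ncoef,
	dropout_prob=dropout_prob, sm_type=softmax)
model = model_classes[args.model](**common_kwargs)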
Code example #2
 elif args.model == 'fb':
     model = model_.cnn_lstm_fb(n_z=args.latent_size,
                                proj_size=len(list(labels_dict.keys())),
                                sm_type=args.softmax)
 elif args.model == 'resnet_fb':
     model = model_.ResNet_fb(n_z=args.latent_size,
                              proj_size=len(list(labels_dict.keys())),
                              sm_type=args.softmax)
 elif args.model == 'resnet_mfcc':
     model = model_.ResNet_mfcc(n_z=args.latent_size,
                                proj_size=len(list(labels_dict.keys())),
                                ncoef=args.ncoef,
                                sm_type=args.softmax)
 elif args.model == 'resnet_lstm':
     model = model_.ResNet_lstm(n_z=args.latent_size,
                                proj_size=len(list(labels_dict.keys())),
                                ncoef=args.ncoef,
                                sm_type=args.softmax)
 elif args.model == 'resnet_stats':
     model = model_.ResNet_stats(n_z=args.latent_size,
                                 proj_size=len(list(labels_dict.keys())),
                                 ncoef=args.ncoef,
                                 sm_type=args.softmax)
 elif args.model == 'lcnn9_mfcc':
     model = model_.lcnn_9layers(n_z=args.latent_size,
                                 proj_size=len(list(labels_dict.keys())),
                                 ncoef=args.ncoef,
                                 sm_type=args.softmax)
 elif args.model == 'lcnn29_mfcc':
     model = model_.lcnn_29layers_v2(n_z=args.latent_size,
                                     proj_size=len(list(
                                         labels_dict.keys())),
Code example #3
        raise ValueError(
            'There is no checkpoint/model path. Use arg --cp-path to indicate the path!'
        )

    print('Cuda Mode is: {}'.format(args.cuda))

    if args.cuda:
        device = get_freer_gpu()

    if args.model == 'resnet_mfcc':
        model = model_.ResNet_mfcc(n_z=args.latent_size,
                                   proj_size=None,
                                   ncoef=args.ncoef)
    elif args.model == 'resnet_lstm':
        model = model_.ResNet_lstm(n_z=args.latent_size,
                                   proj_size=None,
                                   ncoef=args.ncoef)
    elif args.model == 'resnet_stats':
        model = model_.ResNet_stats(n_z=args.latent_size,
                                    proj_size=None,
                                    ncoef=args.ncoef)
    elif args.model == 'resnet_large':
        model = model_.ResNet_large(n_z=args.latent_size,
                                    proj_size=None,
                                    ncoef=args.ncoef)

    ckpt = torch.load(args.cp_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt['model_state'], strict=False)

    model.eval()
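
A hedged sketch of a typical next step after the snippet above: scoring a dummy feature batch with the restored model under torch.no_grad(). The input shape and the (mu, emb) return signature are assumptions borrowed from the test snippets later on this page (examples #6 and #11), not taken from this file:

import torch

with torch.no_grad():
    batch = torch.rand(5, 1, args.ncoef, 200)  # assumed shape: (batch, channel, MFCCs, frames)
    if args.cuda:
        model = model.to(device)
        batch = batch.to(device)
    mu, emb = model.forward(batch)  # assumed (mu, emb) return, as in the test scripts
    print(mu.size(), emb.size())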
Code example #4
File: embedd.py  Project: twistedmove/multitask_asv
            import cupy
            cupy.cuda.Device(int(str(device).split(':')[-1])).use()

    if args.model == 'resnet_mfcc':
        model = model_.ResNet_mfcc(n_z=args.latent_size,
                                   proj_size=0,
                                   ncoef=args.ncoef,
                                   delta=args.delta)
    elif args.model == 'resnet_34':
        model = model_.ResNet_34(n_z=args.latent_size,
                                 proj_size=0,
                                 ncoef=args.ncoef,
                                 delta=args.delta)
    elif args.model == 'resnet_lstm':
        model = model_.ResNet_lstm(n_z=args.latent_size,
                                   proj_size=0,
                                   ncoef=args.ncoef,
                                   delta=args.delta)
    elif args.model == 'resnet_qrnn':
        model = model_.ResNet_qrnn(n_z=args.latent_size,
                                   proj_size=0,
                                   ncoef=args.ncoef,
                                   delta=args.delta)
    elif args.model == 'resnet_stats':
        model = model_.ResNet_stats(n_z=args.latent_size,
                                    proj_size=0,
                                    ncoef=args.ncoef,
                                    delta=args.delta)
    elif args.model == 'resnet_large':
        model = model_.ResNet_large(n_z=args.latent_size,
                                    proj_size=0,
                                    ncoef=args.ncoef,
Code example #5
File: test_load_arch.py  Project: twistedmove/e2e_LID
parser.add_argument('--pairwise',
                    action='store_true',
                    default=False,
                    help='Enables layer-wise comparison of norms')
args = parser.parse_args()

if args.model == 'mfcc':
    model = model_.cnn_lstm_mfcc(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'fb':
    model = model_.cnn_lstm_fb(n_z=args.latent_size)
elif args.model == 'resnet_fb':
    model = model_.ResNet_fb(n_z=args.latent_size)
elif args.model == 'resnet_mfcc':
    model = model_.ResNet_mfcc(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'resnet_stats':
    model = model_.ResNet_stats(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'lcnn9_mfcc':
    model = model_.lcnn_9layers(n_z=args.latent_size, ncoef=args.ncoef)
elif args.model == 'lcnn29_mfcc':
    model = model_.lcnn_29layers_v2(n_z=args.latent_size, ncoef=args.ncoef)

if args.pairwise:

    if args.model == 'mfcc':
        clone_model = model_.cnn_lstm_mfcc(n_z=args.latent_size,
                                           ncoef=args.ncoef)
    elif args.model == 'fb':
        clone_model = model_.cnn_lstm_fb(n_z=args.latent_size)
    elif args.model == 'resnet_fb':
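
The --pairwise flag above enables a layer-wise comparison of parameter norms between model and a freshly constructed clone_model; the comparison itself is cut off in this snippet. A minimal sketch of what such a comparison could look like (hypothetical, not the code from test_load_arch.py):

# Compare the norm of each parameter tensor in the two identically-built models.
for (name, p), (_, p_clone) in zip(model.named_parameters(),
                                   clone_model.named_parameters()):
    print('{}: {:.4f} vs {:.4f}'.format(name, p.norm().item(), p_clone.norm().item()))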
Code example #6
    print('resnet_mfcc', mu.size(), emb.size(), out.size())
if args.model == 'resnet_34' or args.model == 'all':
    batch = torch.rand(3, 3 if args.delta else 1, args.ncoef, 200)
    model = model_.ResNet_34(n_z=args.latent_size,
                             ncoef=args.ncoef,
                             delta=args.delta,
                             proj_size=10,
                             sm_type='softmax')
    mu, emb = model.forward(batch)
    out = model.out_proj(mu, torch.ones(mu.size(0)))
    print('resnet_34', mu.size(), emb.size(), out.size())
if args.model == 'resnet_lstm' or args.model == 'all':
    batch = torch.rand(3, 3 if args.delta else 1, args.ncoef, 200)
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               ncoef=args.ncoef,
                               delta=args.delta,
                               proj_size=10,
                               sm_type='softmax')
    mu, emb = model.forward(batch)
    out = model.out_proj(mu, torch.ones(mu.size(0)))
    print('resnet_lstm', mu.size(), emb.size(), out.size())
if (args.model == 'resnet_qrnn' or args.model == 'all') and torch.cuda.is_available():
    device = get_freer_gpu()
    import cupy
    cupy.cuda.Device(int(str(device).split(':')[-1])).use()
    batch = torch.rand(3, 3 if args.delta else 1, args.ncoef, 200).to(device)
    model = model_.ResNet_qrnn(n_z=args.latent_size,
                               ncoef=args.ncoef,
                               delta=args.delta,
                               proj_size=10,
Code example #7
import argparse

import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torch.utils.data
import model as model_

# Training settings
parser = argparse.ArgumentParser(description='Test new architectures')
parser.add_argument('--latent-size',
                    type=int,
                    default=200,
                    metavar='S',
                    help='latent layer dimension (default: 200)')
parser.add_argument('--ncoef',
                    type=int,
                    default=23,
                    metavar='N',
                    help='number of MFCCs (default: 23)')
args = parser.parse_args()

batch = torch.rand(3, 1, args.ncoef, 200)
model_s = model_.ResNet_lstm(n_z=args.latent_size,
                             ncoef=args.ncoef,
                             proj_size=10)
model_l = model_.ResNet_mfcc(n_z=args.latent_size,
                             ncoef=args.ncoef,
                             proj_size=20)
mu_l, h, c = model_l(batch)
print(mu_l.size(), h.size(), c.size())
mu = model_s.forward(batch, h, c)
print(mu.size())
Code example #8
elif args.model == 'resnet_mfcc':
    model = model_.ResNet_mfcc(n_z=args.latent_size,
                               nh=args.n_hidden,
                               n_h=args.hidden_size,
                               proj_size=train_dataset.n_speakers,
                               ncoef=args.ncoef,
                               dropout_prob=args.dropout_prob,
                               sm_type=args.softmax,
                               ndiscriminators=args.ndiscriminators,
                               r_proj_size=args.rproj_size)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               nh=args.n_hidden,
                               n_h=args.hidden_size,
                               proj_size=train_dataset.n_speakers,
                               ncoef=args.ncoef,
                               dropout_prob=args.dropout_prob,
                               sm_type=args.softmax,
                               ndiscriminators=args.ndiscriminators,
                               r_proj_size=args.rproj_size)
elif args.model == 'resnet_small':
    model = model_.ResNet_small(n_z=args.latent_size,
                                nh=args.n_hidden,
                                n_h=args.hidden_size,
                                proj_size=train_dataset.n_speakers,
                                ncoef=args.ncoef,
                                dropout_prob=args.dropout_prob,
                                sm_type=args.softmax,
                                ndiscriminators=args.ndiscriminators,
                                r_proj_size=args.rproj_size)
elif args.model == 'resnet_large':
Code example #9
    if args.model == 'resnet_stats':
        model = model_.ResNet_stats(n_z=args.latent_size,
                                    nh=args.n_hidden,
                                    n_h=args.hidden_size,
                                    proj_size=1,
                                    ncoef=args.ncoef)
    elif args.model == 'resnet_mfcc':
        model = model_.ResNet_mfcc(n_z=args.latent_size,
                                   nh=args.n_hidden,
                                   n_h=args.hidden_size,
                                   proj_size=1,
                                   ncoef=args.ncoef)
    elif args.model == 'resnet_lstm':
        model = model_.ResNet_lstm(n_z=args.latent_size,
                                   nh=args.n_hidden,
                                   n_h=args.hidden_size,
                                   proj_size=1,
                                   ncoef=args.ncoef)
    elif args.model == 'resnet_small':
        model = model_.ResNet_small(n_z=args.latent_size,
                                    nh=args.n_hidden,
                                    n_h=args.hidden_size,
                                    proj_size=1,
                                    ncoef=args.ncoef)
    elif args.model == 'resnet_large':
        model = model_.ResNet_large(n_z=args.latent_size,
                                    nh=args.n_hidden,
                                    n_h=args.hidden_size,
                                    proj_size=1,
                                    ncoef=args.ncoef)
Code example #10
File: train.py  Project: entn-at/conditional_asv
valid_dataset = Loader(hdf5_name=args.valid_hdf_file,
                       max_nb_frames=args.n_frames,
                       n_cycles=args.valid_n_cycles)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.workers,
                                           worker_init_fn=set_np_randomseed)

if args.cuda:
    device = get_freer_gpu()
else:
    device = None

model_s = model_.ResNet_lstm(n_z=args.latent_size,
                             ncoef=args.ncoef,
                             proj_size=len(train_dataset.speakers_list),
                             sm_type=args.softmax)
model_l = model_.ResNet_mfcc(n_z=args.latent_size,
                             ncoef=args.ncoef,
                             proj_size=len(train_dataset.languages_list),
                             sm_type=args.softmax)

if args.pretrained_s_path is not None:
    ckpt = torch.load(args.pretrained_s_path,
                      map_location=lambda storage, loc: storage)

    try:
        model_s.load_state_dict(ckpt['model_state'], strict=True)
    except RuntimeError as err:
        print("Runtime Error: {0}".format(err))
    except:
Code example #11
                               r_proj_size=args.rproj_size)
    print('resnet_mfcc')
    mu, emb = model.forward(batch)
    print(mu.size())
    emb = torch.cat([emb, emb], 1)
    print(emb.size())
    pred = model.forward_bin(emb)
    print(pred)
    scores_p = model.forward_bin(emb)
    print(scores_p)
if args.model == 'resnet_lstm' or args.model == 'all':
    batch = torch.rand(3, 1, args.ncoef, 200)
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               nh=args.n_hidden,
                               n_h=args.hidden_size,
                               proj_size=100,
                               ncoef=args.ncoef,
                               ndiscriminators=args.ndiscriminators,
                               r_proj_size=args.rproj_size)
    print('resnet_lstm')
    mu, emb = model.forward(batch)
    print(mu.size())
    emb = torch.cat([emb, emb], 1)
    print(emb.size())
    pred = model.forward_bin(emb)
    print(pred)
    scores_p = model.forward_bin(emb)
    print(scores_p)
if args.model == 'resnet_small' or args.model == 'all':
    batch = torch.rand(3, 1, args.ncoef, 200)
    model = model_.ResNet_small(n_z=args.latent_size,
Code example #12
File: train.py  Project: joaomonteirof/asv_base
if args.model == 'resnet_mfcc':
    model = model_.ResNet_mfcc(n_z=args.latent_size,
                               proj_size=train_dataset.n_speakers,
                               ncoef=args.ncoef,
                               sm_type=args.softmax,
                               delta=args.delta)
elif args.model == 'resnet_34':
    model = model_.ResNet_34(n_z=args.latent_size,
                             proj_size=train_dataset.n_speakers,
                             ncoef=args.ncoef,
                             sm_type=args.softmax,
                             delta=args.delta)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               proj_size=train_dataset.n_speakers,
                               ncoef=args.ncoef,
                               sm_type=args.softmax,
                               delta=args.delta)
elif args.model == 'resnet_qrnn':
    model = model_.ResNet_qrnn(n_z=args.latent_size,
                               proj_size=train_dataset.n_speakers,
                               ncoef=args.ncoef,
                               sm_type=args.softmax,
                               delta=args.delta)
elif args.model == 'resnet_stats':
    model = model_.ResNet_stats(n_z=args.latent_size,
                                proj_size=train_dataset.n_speakers,
                                ncoef=args.ncoef,
                                sm_type=args.softmax,
                                delta=args.delta)
elif args.model == 'resnet_large':
Code example #13
File: train_hp.py  Project: entn-at/e2e_verification
                                ncoef=args.ncoef,
                                dropout_prob=args.dropout_prob,
                                sm_type=args.softmax)
elif args.model == 'resnet_mfcc':
    model = model_.ResNet_mfcc(n_z=args.latent_size,
                               nh=args.n_hidden,
                               n_h=args.hidden_size,
                               proj_size=len(train_dataset.speakers_list),
                               ncoef=args.ncoef,
                               dropout_prob=args.dropout_prob,
                               sm_type=args.softmax)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               nh=args.n_hidden,
                               n_h=args.hidden_size,
                               proj_size=len(train_dataset.speakers_list),
                               ncoef=args.ncoef,
                               dropout_prob=args.dropout_prob,
                               sm_type=args.softmax)
elif args.model == 'resnet_small':
    model = model_.ResNet_small(n_z=args.latent_size,
                                nh=args.n_hidden,
                                n_h=args.hidden_size,
                                proj_size=len(train_dataset.speakers_list),
                                ncoef=args.ncoef,
                                dropout_prob=args.dropout_prob,
                                sm_type=args.softmax)
elif args.model == 'resnet_large':
    model = model_.ResNet_large(n_z=args.latent_size,
                                nh=args.n_hidden,
                                n_h=args.hidden_size,
Code example #14
File: train_olr.py  Project: joaomonteirof/e2e_LID
                               sm_type=args.softmax)
elif args.model == 'resnet_fb':
    model = model_.ResNet_fb(n_z=args.latent_size,
                             proj_size=len(train_dataset.speakers_list)
                             if args.softmax != 'none' else 0,
                             sm_type=args.softmax)
elif args.model == 'resnet_mfcc':
    model = model_.ResNet_mfcc(n_z=args.latent_size,
                               proj_size=len(train_dataset.speakers_list)
                               if args.softmax != 'none' else 0,
                               ncoef=args.ncoef,
                               sm_type=args.softmax)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               proj_size=len(train_dataset.speakers_list)
                               if args.softmax != 'none' else 0,
                               ncoef=args.ncoef,
                               sm_type=args.softmax)
elif args.model == 'resnet_stats':
    model = model_.ResNet_stats(n_z=args.latent_size,
                                proj_size=len(train_dataset.speakers_list)
                                if args.softmax != 'none' else 0,
                                ncoef=args.ncoef,
                                sm_type=args.softmax)
elif args.model == 'lcnn9_mfcc':
    model = model_.lcnn_9layers(n_z=args.latent_size,
                                proj_size=len(train_dataset.speakers_list)
                                if args.softmax != 'none' else 0,
                                ncoef=args.ncoef,
                                sm_type=args.softmax)
elif args.model == 'lcnn29_mfcc':
Code example #15
    if args.cuda:
        device = get_freer_gpu()

    ckpt = torch.load(args.cp_path, map_location=lambda storage, loc: storage)

    if args.model == 'resnet_mfcc':
        model = model_.ResNet_mfcc(n_z=ckpt['latent_size'],
                                   nh=ckpt['n_hidden'],
                                   n_h=ckpt['hidden_size'],
                                   proj_size=ckpt['r_proj_size'],
                                   ncoef=ckpt['ncoef'],
                                   ndiscriminators=ckpt['ndiscriminators'])
    elif args.model == 'resnet_lstm':
        model = model_.ResNet_lstm(n_z=ckpt['latent_size'],
                                   nh=ckpt['n_hidden'],
                                   n_h=ckpt['hidden_size'],
                                   proj_size=ckpt['r_proj_size'],
                                   ncoef=ckpt['ncoef'],
                                   ndiscriminators=ckpt['ndiscriminators'])
    elif args.model == 'resnet_stats':
        model = model_.ResNet_stats(n_z=ckpt['latent_size'],
                                    nh=ckpt['n_hidden'],
                                    n_h=ckpt['hidden_size'],
                                    proj_size=ckpt['r_proj_size'],
                                    ncoef=ckpt['ncoef'],
                                    ndiscriminators=ckpt['ndiscriminators'])
    elif args.model == 'resnet_small':
        model = model_.ResNet_small(n_z=ckpt['latent_size'],
                                    nh=ckpt['n_hidden'],
                                    n_h=ckpt['hidden_size'],
                                    proj_size=ckpt['r_proj_size'],
                                    ncoef=ckpt['ncoef'],
Code example #16
                               proj_size=train_dataset.n_speakers if
                               args.softmax != 'none' or args.pretrain else 0,
                               ncoef=args.ncoef,
                               sm_type=args.softmax,
                               delta=args.delta)
elif args.model == 'resnet_34':
    model = model_.ResNet_34(n_z=args.latent_size,
                             proj_size=train_dataset.n_speakers
                             if args.softmax != 'none' or args.pretrain else 0,
                             ncoef=args.ncoef,
                             sm_type=args.softmax,
                             delta=args.delta)
elif args.model == 'resnet_lstm':
    model = model_.ResNet_lstm(n_z=args.latent_size,
                               proj_size=train_dataset.n_speakers if
                               args.softmax != 'none' or args.pretrain else 0,
                               ncoef=args.ncoef,
                               sm_type=args.softmax,
                               delta=args.delta)
elif args.model == 'resnet_qrnn':
    model = model_.ResNet_qrnn(n_z=args.latent_size,
                               proj_size=train_dataset.n_speakers if
                               args.softmax != 'none' or args.pretrain else 0,
                               ncoef=args.ncoef,
                               sm_type=args.softmax,
                               delta=args.delta)
elif args.model == 'resnet_stats':
    model = model_.ResNet_stats(n_z=args.latent_size,
                                proj_size=train_dataset.n_speakers if
                                args.softmax != 'none' or args.pretrain else 0,
                                ncoef=args.ncoef,
                                sm_type=args.softmax,
Code example #17
def train(lr, l2, max_gnorm, momentum, margin, lambda_, swap, latent_size, n_frames, model, ncoef, epochs, batch_size, valid_batch_size, n_workers, cuda, train_hdf_file, valid_hdf_file, cp_path, softmax, delta, logdir):

	if cuda:
		device=get_freer_gpu()
		if model == 'resnet_qrnn':
			import cupy
			cupy.cuda.Device(int(str(device).split(':')[-1])).use()

	cp_name = get_file_name(cp_path)

	if logdir:
		from torch.utils.tensorboard import SummaryWriter
		writer = SummaryWriter(log_dir=logdir+cp_name, comment=model, purge_step=True)
	else:
		writer = None

	train_dataset = Loader(hdf5_name = train_hdf_file, max_nb_frames = int(n_frames), delta = delta)
	train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, worker_init_fn=set_np_randomseed)

	valid_dataset = Loader_valid(hdf5_name = valid_hdf_file, max_nb_frames = int(n_frames), delta = delta)
	valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=True, num_workers=n_workers, worker_init_fn=set_np_randomseed)

	if model == 'resnet_mfcc':
		model=model_.ResNet_mfcc(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'resnet_34':
		model=model_.ResNet_34(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'resnet_lstm':
		model=model_.ResNet_lstm(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'resnet_qrnn':
		model=model_.ResNet_qrnn(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'resnet_stats':
		model=model_.ResNet_stats(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'resnet_large':
		model = model_.ResNet_large(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'resnet_small':
		model = model_.ResNet_small(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'resnet_2d':
		model = model_.ResNet_2d(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'TDNN':
		model = model_.TDNN(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'TDNN_att':
		model = model_.TDNN_att(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'TDNN_multihead':
		model = model_.TDNN_multihead(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'TDNN_lstm':
		model = model_.TDNN_lstm(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'TDNN_aspp':
		model = model_.TDNN_aspp(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'TDNN_mod':
		model = model_.TDNN_mod(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'TDNN_multipool':
		model = model_.TDNN_multipool(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)
	elif model == 'transformer':
		model = model_.transformer_enc(n_z=int(latent_size), proj_size=train_dataset.n_speakers, ncoef=ncoef, sm_type=softmax, delta=delta)

	if cuda:
		model=model.to(device)
	else:
		device=None

	optimizer=optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=l2)

	trainer=TrainLoop(model, optimizer, train_loader, valid_loader, max_gnorm=max_gnorm, margin=margin, lambda_=lambda_, verbose=-1, device=device, cp_name=cp_name, save_cp=True, checkpoint_path=cp_path, swap=swap, softmax=True, pretrain=False, mining=True, cuda=cuda, logger=writer)

	return trainer.train(n_epochs=epochs)
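
A hypothetical call into the train() function defined above, for instance from a hyperparameter-search driver; every value below is a placeholder chosen for illustration, not a default from the repository:

# Placeholder hyperparameters and file paths, purely for illustration.
best = train(lr=1e-2, l2=1e-4, max_gnorm=10.0, momentum=0.9, margin=0.3, lambda_=0.1,
	swap=False, latent_size=256, n_frames=800, model='resnet_lstm', ncoef=23,
	epochs=50, batch_size=64, valid_batch_size=64, n_workers=4, cuda=True,
	train_hdf_file='train.hdf', valid_hdf_file='valid.hdf', cp_path='./cp/',
	softmax='softmax', delta=False, logdir='./logs/')
print(best)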