示例#1
0
def main():
    """Train a ``Time_Delay`` network and evaluate it after every epoch.

    Configuration is read from the module-level ``args`` and ``kwargs``; data
    comes from the module-level ``train_dir``, ``valid_dir`` and ``test_dir``.
    When ``args.resume`` names an existing file, the model weights, optimizer
    state and starting epoch are restored from it before training.
    """
    # print the experiment configuration
    print('\33[91m \nCurrent time is {}\33[0m'.format(str(time.asctime())))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    context = [[-2, 2], [-2, 0, 2], [-3, 0, 3], [0], [0]]
    # the same configuration as x-vector
    node_num = [64, 128, 256, 512, 1024, 1024, 512, 512]
    full_context = [True, False, False, True, True]

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               collate_fn=PadCollate(dim=2),
                                               shuffle=True, **kwargs)
    # The keyword must be spelled ``collate_fn``; the previous ``ollate_fn``
    # made DataLoader raise TypeError for an unexpected keyword argument.
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.test_batch_size,
                                               collate_fn=PadCollate(dim=2),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    model = Time_Delay(context, 64, len(train_dir.classes), node_num, full_context)

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model, args.lr)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            # Read the file once; the same dict provides the epoch, the
            # weights and the optimizer state.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # Drop every entry whose key contains 'num_batches_tracked'
            # before handing the weights to the model.
            filtered = {k: v for k, v in checkpoint['state_dict'].items()
                        if 'num_batches_tracked' not in k}

            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('start epoch is : ' + str(start))
    end = start + args.epochs

    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
def main():
    """Evaluate saved ``XVectorTDNN`` checkpoints for epochs 1 through 14.

    ``args.resume`` is treated as a format template: it is formatted with
    each epoch number to obtain a checkpoint path. Every checkpoint that
    exists is loaded into the model, the model is put in eval mode, and
    ``test`` is called with the epoch stored inside the checkpoint. Missing
    checkpoints are reported and skipped.
    """
    # print the experiment configuration
    print('\33[91mCurrent time is {}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # instantiate model and initialize weights
    model = XVectorTDNN(len(train_dir.speakers), dropout_p=0.0)

    if args.cuda:
        model.cuda()

    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=int(args.batch_size /
                                                              2),
                                               collate_fn=PadCollate(dim=1),
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    for epoch in range(1, 15):
        check_path = args.resume.format(epoch)
        if not os.path.isfile(check_path):
            # Skip instead of falling through: previously a missing file left
            # ``start`` unbound (first iteration) or re-tested the previously
            # loaded weights under a stale epoch number.
            print('=> no checkpoint found at {}'.format(check_path))
            continue

        print('=> loading checkpoint {}'.format(check_path))
        checkpoint = torch.load(check_path)
        start = checkpoint['epoch']

        # Drop every entry whose key contains 'num_batches_tracked'.
        filtered = {
            k: v
            for k, v in checkpoint['state_dict'].items()
            if 'num_batches_tracked' not in k
        }
        model.load_state_dict(filtered)

        model.eval()

        test(test_loader, valid_loader, model, start)
def main():
    """Train a ``SuperficialResCNN`` model with a step LR schedule.

    Configuration is read from the module-level ``args`` and ``kwargs``; data
    comes from the module-level ``train_dir``, ``valid_dir`` and ``test_dir``.
    When ``args.resume`` names an existing file, only the model weights and
    the starting epoch are restored (the optimizer state is not). After each
    epoch ``test`` is run and the scheduler is stepped; the module-level
    ``writer`` is closed at the end.
    """
    # print the experiment configuration
    print('\33[91m\nCurrent time is {}\33[0m'.format(str(time.asctime())))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    # instantiate model and initialize weights
    model = SuperficialResCNN(layers=[1, 1, 1, 1],
                              embedding_size=args.embedding_size,
                              n_classes=len(train_dir.classes),
                              m=args.margin)

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model, args.lr)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            # Read the file once; it was previously loaded twice in a row.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # Drop every entry whose key contains 'num_batches_tracked'.
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }

            model.load_state_dict(filtered)
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('start epoch is : ' + str(start))
    end = start + args.epochs
    # Multiply the learning rate by 0.1 every 15 scheduler steps.
    scheduler = StepLR(optimizer, step_size=15, gamma=0.1)

    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               collate_fn=PadCollate(dim=2),
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=args.test_batch_size,
                                               shuffle=False,
                                               collate_fn=PadCollate(dim=2),
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()

    writer.close()
def main():
    """Train the model selected by ``args.model``, checkpoint it, then verify.

    Steps, all driven by the module-level ``args``:

    1. Build the model with ``create_model`` from kernel/channel options
       parsed out of comma-separated strings.
    2. Optionally save the initial model (``args.save_init``) and restore
       weights from ``args.resume``.
    3. Select the loss modules for ``args.loss_type`` ('soft', 'asoft',
       'center' or 'amsoft'), then build the optimizer and LR scheduler.
    4. Train for ``args.epochs`` epochs, saving a checkpoint when
       ``epoch % 4 == 1`` or on the last epoch, and calling ``test`` on odd
       epochs other than the last.
    5. Run ``verification_extract`` and ``verification_test`` on
       ``args.test_dir`` and close the module-level ``writer``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the supported values.
    """
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    # Padding of (k - 1) / 2, truncated toward zero, for each kernel size k.
    padding = tuple(int((x - 1) / 2) for x in kernel_size)
    kernel_size = tuple(kernel_size)

    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {'embedding_size': args.embedding_size,
                    'inst_norm': args.inst_norm,
                    'resnet_size': args.resnet_size,
                    'num_classes': train_dir.num_spks,
                    'channels': channels,
                    'avg_size': args.avg_size,
                    'alpha': args.alpha,
                    'kernel_size': kernel_size,
                    'padding': padding,
                    'dropout_p': args.dropout_p}

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save(model, check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            # Overlay the checkpoint weights (minus 'num_batches_tracked'
            # entries) on the freshly built model's state so that keys missing
            # from the checkpoint keep their initial values.
            filtered = {k: v for k, v in checkpoint['state_dict'].items()
                        if 'num_batches_tracked' not in k}
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
            # Re-apply the requested dropout probability after loading.
            model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # ce_criterion / xe_criterion are handed to ``train`` as a pair; either
    # slot may be None depending on the loss type.
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    else:
        # Previously an unknown value left xe_criterion unbound and failed
        # further down with an UnboundLocalError.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    if args.loss_type == 'center':
        # The center-loss parameters get their own group at 5x the base rate.
        optimizer = torch.optim.SGD([{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
                                     {'params': model.parameters()}],
                                    lr=args.lr, weight_decay=args.weight_decay,
                                    momentum=args.momentum)
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            # The replaced classifier gets 5x the base rate; all remaining
            # parameters use the base rate.
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params, model.parameters())
            optimizer = torch.optim.SGD([{'params': model.classifier.parameters(), 'lr': args.lr * 5},
                                         {'params': rest_params}],
                                        lr=args.lr, weight_decay=args.weight_decay,
                                        momentum=args.momentum)

    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    else:
        milestones = sorted(int(x) for x in args.milestones.split(','))
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               collate_fn=PadCollate(dim=2, fix_len=False,
                                                                     min_chunk_size=250, max_chunk_size=450),
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               collate_fn=PadCollate(dim=2, fix_len=False,
                                                                     min_chunk_size=250, max_chunk_size=450),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    if args.cuda:
        model = model.cuda()
        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()
        print('Dropout is {}.'.format(model.dropout_p))

    for epoch in range(start, end):
        # Show the current learning rate of every parameter group.
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, ce, optimizer, epoch)
        if epoch % 4 == 1 or epoch == (end - 1):
            check_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)
            torch.save({'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'criterion': ce},
                       check_path)

        if epoch % 2 == 1 and epoch != (end - 1):
            test(test_loader, valid_loader, model, epoch)
        scheduler.step()

    # Post-training verification: write vectors next to the checkpoints (in a
    # path with 'checkpoint' replaced by 'xvector'), then score the trials.
    extract_dir = KaldiExtractDataset(dir=args.test_dir, transform=transform_V, filer_loader=file_loader)
    extract_loader = torch.utils.data.DataLoader(extract_dir, batch_size=1, shuffle=False, **kwargs)
    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')
    verification_extract(extract_loader, model, xvector_dir)

    verify_dir = ScriptVerifyDataset(dir=args.test_dir, trials_file=args.trials, xvectors_dir=xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir, batch_size=64, shuffle=False, **kwargs)
    verification_test(test_loader=verify_loader, dist_type=('cos' if args.cos_sim else 'l2'),
                      log_interval=args.log_interval)

    writer.close()
def main():
    """Train a ``Time_Delay`` network with cross-entropy and a multi-step LR.

    Configuration is read from the module-level ``args``, ``kwargs`` and
    ``opt_kwargs``; data comes from the module-level ``train_dir``,
    ``valid_dir`` and ``test_part``. When ``args.resume`` names an existing
    file, the model weights, optimizer state and starting epoch are restored
    from it. After every epoch ``test`` is run and the scheduler is stepped.
    """
    # print the experiment configuration
    print('\33[91m \nCurrent time is {}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    context = [[-2, 2], [-2, 0, 2], [-3, 0, 3], [0], [0]]
    # the same configuration as x-vector
    node_num = [512, 512, 512, 512, 1500, 3000, 512, 512]
    full_context = [True, False, False, True, True]

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size, shuffle=True,
                                               collate_fn=PadCollate(dim=2), **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.batch_size, shuffle=False,
                                               collate_fn=PadCollate(dim=2), **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    model = Time_Delay(context, 24, len(train_dir.classes), node_num, full_context)

    if args.cuda:
        model = model.cuda()

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    # Multiply the learning rate by 0.1 at scheduler steps 16 and 24.
    scheduler = MultiStepLR(optimizer, milestones=[16, 24], gamma=0.1)

    # Honour args.cuda for the loss as well; calling .cuda() unconditionally
    # fails on hosts without CUDA even when the model stays on the CPU.
    ce_loss = nn.CrossEntropyLoss()
    if args.cuda:
        ce_loss = ce_loss.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            # Read the file once; it was previously loaded twice in a row.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Drop every entry whose key contains 'num_batches_tracked'.
            filtered = {k: v for k, v in checkpoint['state_dict'].items()
                        if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('start epoch is : ' + str(start))
    end = start + args.epochs

    for epoch in range(start, end):
        train(train_loader, model, ce_loss, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()
示例#6
0
def main():
    """Train the model selected by ``args.model`` and test it every epoch.

    Steps, all driven by the module-level ``args``:

    1. Build the model with ``create_model``.
    2. Optionally save the initial model (``args.save_init``) and restore
       weights from ``args.resume``.
    3. Select the loss modules for ``args.loss_type`` ('soft', 'asoft',
       'center' or 'amsoft'), then build the optimizer and LR scheduler.
    4. For ``args.epochs`` epochs: print the learning rates, call ``train``
       and ``test``, and step the scheduler.
    5. Close the module-level ``writer``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the supported values.
    """
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    model_kwargs = {
        'embedding_size': args.embedding_size,
        'num_classes': train_dir.num_spks,
        'input_dim': args.feat_dim,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path,
                                                   start_epoch)
        torch.save(model, check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            # Overlay the checkpoint weights (minus 'num_batches_tracked'
            # entries) on the freshly built model's state so that keys missing
            # from the checkpoint keep their initial values.
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model_dict = model.state_dict()
            model_dict.update(filtered)

            model.load_state_dict(model_dict)
            # Best effort: re-apply the requested dropout probability. Only
            # the "model has no usable ``dropout`` attribute" case is
            # tolerated; anything else now surfaces instead of being
            # swallowed by a bare except.
            try:
                model.dropout.p = args.dropout_p
            except AttributeError:
                pass
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # ce_criterion / xe_criterion are handed to ``train`` as a pair; either
    # slot may be None depending on the loss type.
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    else:
        # Previously an unknown value left xe_criterion unbound and failed
        # further down with an UnboundLocalError.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    optimizer = create_optimizer(model.parameters(), args.optimizer,
                                 **opt_kwargs)
    if args.loss_type == 'center':
        # The center-loss parameters get their own group at 5x the base rate.
        optimizer = torch.optim.SGD(
            [{
                'params': xe_criterion.parameters(),
                'lr': args.lr * 5
            }, {
                'params': model.parameters()
            }],
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.momentum)
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            # The replaced classifier gets 5x the base rate; all remaining
            # parameters use the base rate.
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params,
                                 model.parameters())
            optimizer = torch.optim.SGD(
                [{
                    'params': model.classifier.parameters(),
                    'lr': args.lr * 5
                }, {
                    'params': rest_params
                }],
                lr=args.lr,
                weight_decay=args.weight_decay,
                momentum=args.momentum)

    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    else:
        milestones = sorted(int(x) for x in args.milestones.split(','))
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               collate_fn=PadCollate(
                                                   dim=2,
                                                   fix_len=False,
                                                   min_chunk_size=250,
                                                   max_chunk_size=450),
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        valid_dir,
        batch_size=int(args.batch_size / 2),
        collate_fn=PadCollate(dim=2,
                              fix_len=False,
                              min_chunk_size=250,
                              max_chunk_size=450),
        shuffle=False,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    if args.cuda:
        model = model.cuda()
        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()

    for epoch in range(start, end):
        # Show the current learning rate of every parameter group.
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(
            args.optimizer),
              end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, ce, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()

    writer.close()
示例#7
0
def main():
    """Train a ``DeepSpeakerModel`` and evaluate it after each epoch.

    Everything is configured through the module-level ``args`` and ``kwargs``;
    the datasets are the module-level ``train_dir``, ``valid_dir`` and
    ``test_dir``. If ``args.resume`` points at an existing file, the model
    weights, optimizer state and starting epoch are taken from it. The
    module-level ``writer`` is closed once training finishes.
    """
    # Show the run configuration first.
    timestamp = str(time.asctime())
    print('\33[91m\nCurrent time is {}.\33[0m'.format(timestamp))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    # Build the network, then place / wrap it as requested.
    net_options = dict(resnet_size=10,
                       embedding_size=args.embedding_size,
                       num_classes=len(train_dir.classes))
    model = DeepSpeakerModel(**net_options)

    if args.cuda:
        model.cuda()

    if args.data_parallel:
        model = torch.nn.DataParallel(model, device_ids=[2, 3])

    optimizer = create_optimizer(model, args.lr)

    # Restore from a checkpoint when one was requested.
    resume_path = args.resume
    if resume_path and not os.path.isfile(resume_path):
        print('=> no checkpoint found at {}'.format(resume_path))
    elif resume_path:
        print('=> loading checkpoint {}'.format(resume_path))
        snapshot = torch.load(resume_path)
        args.start_epoch = snapshot['epoch']

        # Keep every stored tensor except those whose key mentions
        # 'num_batches_tracked'.
        weights = {}
        for key, value in snapshot['state_dict'].items():
            if 'num_batches_tracked' in key:
                continue
            weights[key] = value

        model.load_state_dict(weights)
        optimizer.load_state_dict(snapshot['optimizer'])

    first_epoch = args.start_epoch
    print('start epoch is : ' + str(first_epoch))
    stop_epoch = first_epoch + args.epochs

    # Data pipelines: the training loader uses the triplet collate function,
    # the validation loader uses the plain padding collate function.
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_dir,
                               batch_size=args.batch_size,
                               shuffle=True,
                               collate_fn=TripletPadCollate(dim=2),
                               **kwargs)
    valid_loader = make_loader(valid_dir,
                               batch_size=args.test_batch_size,
                               shuffle=False,
                               collate_fn=PadCollate(dim=2),
                               **kwargs)
    test_loader = make_loader(test_dir,
                              batch_size=args.test_batch_size,
                              shuffle=False,
                              **kwargs)

    for epoch in range(first_epoch, stop_epoch):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)

    writer.close()
def _parse_int_list(text):
    """Parse a comma-separated string such as ``'5,5'`` into a list of ints."""
    return [int(x) for x in text.split(',')]


def _load_model_weights(model, resume_path, default_epoch=0):
    """Load the weights stored at ``resume_path`` into ``model`` in place.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        resume_path: path of a checkpoint written with ``torch.save`` that
            holds the keys ``'epoch'`` and ``'state_dict'``.
        default_epoch: value returned when ``resume_path`` is not a file.

    Returns:
        The ``'epoch'`` entry of the checkpoint, or ``default_epoch`` when no
        checkpoint file exists (a message is printed in that case).
    """
    if not os.path.isfile(resume_path):
        print('=> no checkpoint found at {}'.format(resume_path))
        return default_epoch

    print('=> loading checkpoint {}'.format(resume_path))
    checkpoint = torch.load(resume_path)

    state_dict = checkpoint['state_dict']
    # Checkpoints written by earlier versions of this script stored the state
    # dict wrapped in a 1-tuple; keep accepting that layout.
    if isinstance(state_dict, tuple):
        state_dict = state_dict[0]

    filtered = {
        k: v
        for k, v in state_dict.items() if 'num_batches_tracked' not in k
    }

    prefix = 'module.'
    first_key = next(iter(filtered), '')
    if first_key.startswith(prefix):
        # Keys saved from a wrapped model carry a leading `module.`; strip it
        # (only where present) so the names match the bare model.
        stripped = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in filtered.items())
        model.load_state_dict(stripped)
    else:
        # Overlay the checkpoint on the current state so that entries missing
        # from the checkpoint keep the model's existing values.
        model_dict = model.state_dict()
        model_dict.update(filtered)
        model.load_state_dict(model_dict)

    return checkpoint['epoch']


def main():
    """Build the model, losses, optimizer and loaders, then run training.

    Reads its whole configuration from the module-level ``args`` and uses the
    module-level datasets (``train_dir``, ``valid_dir``, ``train_extract_dir``),
    loader keyword dicts (``kwargs``, ``extract_kwargs``), ``opt_kwargs`` and
    ``writer``. For every epoch it trains, validates, periodically saves a
    checkpoint and runs the verification test, and steps the LR scheduler.
    A ``KeyboardInterrupt`` stops training early; the function always ends by
    closing ``writer``, printing the average time per epoch and calling
    ``exit(0)``.
    """
    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    options = ["'%s': '%s'" % (str(k), str(opts[k])) for k in sorted(opts)]

    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # Instantiate the model and initialize its weights.
    kernel_size = _parse_int_list(args.kernel_size)
    context = _parse_int_list(args.context)
    if args.padding == '':
        # No explicit padding given: derive it from the kernel size.
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = _parse_int_list(args.padding)

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = _parse_int_list(args.stride)
    channels = _parse_int_list(args.channels)

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'context': context,
        'filter_fix': args.filter_fix,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'exp': args.exp,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))
    dist_type = 'cos' if args.cos_sim else 'l2'
    print('Testing with %s distance, ' % dist_type)

    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path,
                                                   start_epoch)
        torch.save(model, check_path)

    iteration = 0  # if args.resume else 0
    # When fine-tuning, weights are loaded BEFORE the classifier may be
    # replaced by the loss-specific head below.
    if args.finetune and args.resume:
        start_epoch = _load_model_weights(model, args.resume, start_epoch)

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'gaussian':
        xe_criterion = GaussianLoss(num_classes=train_dir.num_spks,
                                    feat_dim=args.embedding_size)
    elif args.loss_type == 'coscenter':
        xe_criterion = CenterCosLoss(num_classes=train_dir.num_spks,
                                     feat_dim=args.embedding_size)
    elif args.loss_type == 'mulcenter':
        xe_criterion = MultiCenterLoss(num_classes=train_dir.num_spks,
                                       feat_dim=args.embedding_size,
                                       num_center=args.num_center)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'arcsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = ArcSoftmaxLoss(margin=args.margin,
                                      s=args.s,
                                      iteraion=iteration,
                                      all_iteraion=args.all_iteraion)
    elif args.loss_type == 'wasse':
        xe_criterion = Wasserstein_Loss(source_cls=args.source_cls)
    elif args.loss_type == 'ring':
        xe_criterion = RingLoss(ring=args.ring)
        args.alpha = 0.0

    # Build optimizer parameter groups. Later branches overwrite earlier
    # ones, so the filter-layer grouping takes precedence when it applies.
    model_para = model.parameters()
    if args.loss_type in [
            'center', 'mulcenter', 'gaussian', 'coscenter', 'ring'
    ]:
        assert args.lr_ratio > 0
        model_para = [{
            'params': xe_criterion.parameters(),
            'lr': args.lr * args.lr_ratio
        }, {
            'params': model.parameters()
        }]
    if args.finetune:
        if args.loss_type in ('asoft', 'amsoft'):
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params,
                                 model.parameters())
            assert args.lr_ratio > 0
            model_para = [{
                'params': model.classifier.parameters(),
                'lr': args.lr * args.lr_ratio
            }, {
                'params': rest_params
            }]

    if args.filter in ['fDLR', 'fBLayer', 'fLLayer', 'fBPLayer']:
        filter_params = list(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params,
                             model.parameters())
        model_para = [{
            'params': model.filter_layer.parameters(),
            'lr': args.lr * args.lr_ratio
        }, {
            'params': rest_params
        }]

    optimizer = create_optimizer(model_para, args.optimizer, **opt_kwargs)

    # Plain resume (no fine-tuning): weights are loaded AFTER the classifier
    # head has been set up, so the checkpoint keys match the final model.
    if not args.finetune and args.resume:
        start_epoch = _load_model_weights(model, args.resume, start_epoch)

    # Save a text description of the model, losses and optimizer.
    conf_path = osp.join(
        args.check_path,
        'model.%s.conf' % time.strftime("%Y.%m.%d", time.localtime()))
    with open(conf_path, 'w') as f:
        f.write('model: ' + str(model) + '\n')
        f.write('CrossEntropy: ' + str(ce_criterion) + '\n')
        f.write('Other Loss: ' + str(xe_criterion) + '\n')
        f.write('Optimizer: ' + str(optimizer) + '\n')

    milestones = sorted(_parse_int_list(args.milestones))
    if args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == 'rop':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   patience=args.patience,
                                                   min_lr=1e-5)
    else:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=milestones,
                                             gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    # start = 0
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(
        train_dir,
        batch_size=args.batch_size,
        collate_fn=PadCollate(dim=2,
                              num_batch=int(
                                  np.ceil(len(train_dir) / args.batch_size)),
                              min_chunk_size=args.min_chunk_size,
                              max_chunk_size=args.max_chunk_size),
        shuffle=args.shuffle,
        **kwargs)

    valid_loader = torch.utils.data.DataLoader(
        valid_dir,
        batch_size=int(args.batch_size / 2),
        collate_fn=PadCollate(dim=2,
                              fix_len=True,
                              min_chunk_size=args.chunk_size,
                              max_chunk_size=args.chunk_size + 1),
        shuffle=False,
        **kwargs)
    train_extract_loader = torch.utils.data.DataLoader(train_extract_dir,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       **extract_kwargs)

    if args.cuda:
        if len(args.gpu_id) > 1:
            print("Continue with gpu: %s ..." % str(args.gpu_id))
            torch.distributed.init_process_group(
                backend="nccl",
                # init_method='tcp://localhost:23456',
                init_method=
                'file:///home/ssd2020/yangwenhao/lstm_speaker_verification/data/sharedfile',
                rank=0,
                world_size=1)
            model = DistributedDataParallel(model.cuda(),
                                            find_unused_parameters=True)

        else:
            model = model.cuda()

        # Move whichever criteria are in use onto the GPU.
        ce = [c.cuda() if c is not None else None for c in ce]
        try:
            print('Dropout is {}.'.format(model.dropout_p))
        except AttributeError:
            # Not every model (or wrapper) exposes `dropout_p`; the message
            # is purely informational.
            pass

    xvector_dir = args.check_path.replace('checkpoint', 'xvector')

    start_time = time.time()
    # Pre-bind `epoch` so the interrupt handler below can always read it.
    epoch = start
    try:

        for epoch in range(start, end):
            # pdb.set_trace()
            lr_string = '\n\33[1;34m Current \'{}\' learning rate is '.format(
                args.optimizer)
            lr_string += ''.join('{:.6f} '.format(group['lr'])
                                 for group in optimizer.param_groups)
            print('%s \33[0m' % lr_string)

            train(train_loader, model, ce, optimizer, epoch)
            valid_loss = valid_class(valid_loader, model, ce, epoch)

            # NOTE(review): the first clause is false only when
            # epoch == end - 2 (and epoch != 1) -- confirm this is intended.
            if (epoch == 1 or epoch !=
                (end - 2)) and (epoch % 4 == 1 or epoch in milestones
                                or epoch == (end - 1)):
                model.eval()
                check_path = '{}/checkpoint_{}.pth'.format(
                    args.check_path, epoch)
                # Save the bare model's weights (without the DDP wrapper).
                # This must be a plain dict: a trailing comma here would wrap
                # it in a 1-tuple.
                if isinstance(model, DistributedDataParallel):
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model_state_dict,
                        'criterion': ce
                    }, check_path)

                valid_test(train_extract_loader, model, epoch, xvector_dir)
                test(model, epoch, writer, xvector_dir)
                if epoch != (end - 1):
                    # Extracted x-vectors are kept only for the last epoch.
                    try:
                        shutil.rmtree("%s/train/epoch_%s" %
                                      (xvector_dir, epoch))
                        shutil.rmtree("%s/test/epoch_%s" %
                                      (xvector_dir, epoch))
                    except Exception as e:
                        print('rm dir xvectors error:', e)

            if args.scheduler == 'rop':
                scheduler.step(valid_loss)
            else:
                scheduler.step()

    except KeyboardInterrupt:
        # Use the epoch reached so far when averaging the running time.
        end = epoch

    writer.close()
    stop_time = time.time()
    t = float(stop_time - start_time)
    print("Running %.4f minutes for each epoch.\n" % (t / 60 /
                                                      (max(end - start, 1))))
    exit(0)