def main():
    """Re-save per-epoch checkpoints as ``<checkpoint>.new`` files.

    A ``LocalResNet`` is built (its classifier is replaced when
    ``args.loss_type`` is ``'asoft'`` or ``'amsoft'``).  For each epoch in
    ``1 .. args.epochs`` the file ``<args.check_path>/checkpoint_<epoch>.pth``
    is looked up; when present, its ``state_dict`` (minus the
    ``num_batches_tracked`` entries) is merged into the model and a new file
    holding the epoch, the model object itself and the stored criterion is
    written next to it.  Missing checkpoints are only reported.
    """
    speaker_count = 1211

    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(speaker_count))

    # Instantiate the network; margin-based losses get their own classifier.
    model = LocalResNet(resnet_size=10, embedding_size=args.embedding_size, num_classes=speaker_count)
    loss_name = args.loss_type
    if loss_name == 'asoft':
        model.classifier = AngleLinear(in_features=args.embedding_size, out_features=speaker_count, m=args.m)
    elif loss_name == 'amsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=speaker_count)

    first_epoch = 1
    print('Start epoch is : ' + str(first_epoch))
    last_epoch = first_epoch + args.epochs

    for epoch in range(first_epoch, last_epoch):
        src_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)

        # Guard clause: nothing to convert for this epoch.
        if not os.path.isfile(src_path):
            print('=> no checkpoint found at {}'.format(src_path))
            continue

        print('=> loading checkpoint {}'.format(src_path))
        snapshot = torch.load(src_path)
        saved_epoch = snapshot['epoch']

        # Overlay the stored weights on the model's current state, leaving out
        # every entry whose key mentions 'num_batches_tracked'.
        kept = {}
        for key, value in snapshot['state_dict'].items():
            if 'num_batches_tracked' in key:
                continue
            kept[key] = value
        merged_state = model.state_dict()
        merged_state.update(kept)
        model.load_state_dict(merged_state)

        saved_criterion = snapshot['criterion']

        dst_path = src_path + '.new'
        payload = {'epoch': saved_epoch,
                   'model': model,
                   'criterion': saved_criterion}
        torch.save(payload, dst_path)

        print('=> Saving new checkpoint at {}'.format(dst_path))
# Example #2
# 0
def main():
    """Evaluate every available per-epoch checkpoint with ``sitw_test``.

    Builds a ``LocalResNet`` (replacing its classifier when
    ``args.loss_type`` is ``'asoft'`` or ``'amsoft'``), moves it to the GPU
    when ``args.cuda`` is set, and then walks epochs ``1 .. args.epochs``.
    For each ``<args.check_path>/checkpoint_<epoch>.pth`` that exists, the
    stored weights are loaded and ``sitw_test`` is called with the two
    loaders, the model and the epoch recorded in the checkpoint.  Missing
    checkpoints are reported and skipped.  The module-level ``writer`` is
    closed at the end.
    """
    num_spks = 1211

    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(num_spks))

    # Instantiate the model; margin-based losses use a dedicated classifier.
    model = LocalResNet(resnet_size=10, embedding_size=args.embedding_size, num_classes=num_spks)

    if args.loss_type == 'asoft':
        model.classifier = AngleLinear(in_features=args.embedding_size, out_features=num_spks, m=args.m)
    elif args.loss_type == 'amsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=num_spks)

    if args.cuda:
        model.cuda()

    # NOTE(review): both loaders wrap ``sitw_dev_dir``; the "test" loader
    # presumably should wrap a separate test set -- confirm against the
    # datasets defined at module level before changing it.
    sitw_test_loader = torch.utils.data.DataLoader(sitw_dev_dir, batch_size=args.test_batch_size,
                                                   shuffle=False, **kwargs)
    sitw_dev_loader = torch.utils.data.DataLoader(sitw_dev_dir, batch_size=args.test_batch_size,
                                                  shuffle=False, **kwargs)

    epochs = np.arange(1, args.epochs + 1)
    resume_path = args.check_path + '/checkpoint_{}.pth'
    for epoch in epochs:
        check_file = resume_path.format(epoch)
        if not os.path.isfile(check_file):
            print('=> no checkpoint found at %s' % check_file)
            continue

        # Load model weights from the checkpoint file.
        print('=> loading checkpoint {}'.format(check_file))
        checkpoint = torch.load(check_file)
        start_epoch = checkpoint['epoch']
        filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}

        # Merge into the model's own state dict before loading.  Loading
        # ``filtered`` directly is a strict load and fails on missing keys
        # once the ``num_batches_tracked`` entries have been dropped.
        model_dict = model.state_dict()
        model_dict.update(filtered)
        model.load_state_dict(model_dict)

        sitw_test(sitw_dev_loader, sitw_test_loader, model, start_epoch)

    writer.close()
def main():
    """Run the extraction helpers for every saved epoch checkpoint.

    The model is created through ``create_model`` from options parsed out of
    ``args`` (comma-separated ``channels`` and ``kernel_size`` strings are
    converted to integers; padding is derived from the kernel size).  For
    each epoch in ``args.start_epochs .. args.epochs`` whose checkpoint file
    exists under ``args.check_path``, the weights are loaded, the model is
    moved to the GPU and ``train_extract`` / ``test_extract`` are called with
    the output directory ``<args.extract_path>/epoch_<e>``.  When
    ``args.test_only`` is set, only the test loader is processed.
    """
    print('\nNumber of Speakers: {}.'.format(train_dir.num_spks))
    # Print the experiment configuration.
    print('Current time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))

    # Parse the comma-separated layer options.
    channels = [int(x) for x in args.channels.split(',')]
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    padding = tuple(int((x - 1) / 2) for x in kernel_size)
    kernel_size = tuple(kernel_size)

    # 'time_dim' used to be listed twice with the same value; one entry is
    # enough and keeps the printed option order unchanged.
    model_kwargs = {
        'input_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'stride': args.stride,
        'padding': padding,
        'channels': channels,
        'alpha': args.alpha,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'resnet_size': args.resnet_size,
        'embedding_size': args.embedding_size,
        'num_classes': len(train_dir.speakers),
        'dropout_p': args.dropout_p,
    }

    print('Model options: {}'.format(model_kwargs))

    model = create_model(args.model, **model_kwargs)
    if args.loss_type == 'asoft':
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
    elif args.loss_type == 'amsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)

    train_loader = DataLoader(train_part,
                              batch_size=args.batch_size,
                              shuffle=False,
                              **kwargs)
    veri_loader = DataLoader(veri_dir,
                             batch_size=args.batch_size,
                             shuffle=False,
                             **kwargs)
    valid_loader = DataLoader(valid_part,
                              batch_size=args.batch_size,
                              shuffle=False,
                              **kwargs)
    test_loader = DataLoader(test_dir,
                             batch_size=args.batch_size,
                             shuffle=False,
                             **kwargs)

    resume_path = args.check_path + '/checkpoint_{}.pth'
    print('=> Saving output in {}\n'.format(args.extract_path))
    epochs = np.arange(args.start_epochs, args.epochs + 1)

    for e in epochs:
        check_file = resume_path.format(e)
        if not os.path.isfile(check_file):
            print('=> no checkpoint found at %s' % check_file)
            continue

        # Load model weights from the checkpoint file.
        print('=> loading checkpoint {}'.format(check_file))
        checkpoint = torch.load(check_file)
        if e == 0:
            # The epoch-0 file is read through ``state_dict()``, i.e. it is
            # handled as a module-like object rather than a dict.
            filtered = checkpoint.state_dict()
        else:
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }

        model_dict = model.state_dict()
        model_dict.update(filtered)
        model.load_state_dict(model_dict)

        # Best effort: mirror the model's dropout setting into ``args`` when
        # the model exposes one (previously a bare ``except: pass``).
        if hasattr(model, 'dropout_p'):
            args.dropout_p = model.dropout_p

        model.cuda()

        file_dir = args.extract_path + '/epoch_%d' % e
        # ``exist_ok`` avoids the check-then-create race of the old code.
        os.makedirs(file_dir, exist_ok=True)

        if not args.test_only:
            train_extract(train_loader, model, file_dir, 'vox1_train')
            train_extract(valid_loader, model, file_dir, 'vox1_valid')
            test_extract(veri_loader, model, file_dir, 'vox1_veri')

        test_extract(test_loader, model, file_dir, 'vox1_test')
def main():
    """Build the two-speaker-set model, optionally restore it, and run ``test``.

    Model options are parsed from ``args`` (comma-separated ``kernel_size``,
    ``padding``, ``stride`` and ``channels`` strings become integer lists)
    and passed to ``create_model``.  When ``args.resume`` names an existing
    file, its ``state_dict`` (minus ``num_batches_tracked`` entries) is
    merged into the model and the stored epoch is used for the final
    ``test`` call; otherwise epoch ``0`` is used.  The loss objects are
    selected by ``args.loss_type`` and moved to the GPU together with the
    model when ``args.cuda`` is set.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the handled values.
    """
    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers for set A: {}.'.format(args.num_spks_a))
    print('Number of Speakers for set B: {}.\n'.format(args.num_spks_b))

    # Parse the comma-separated layer options.
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    if args.padding == '':
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = [int(x) for x in args.padding.split(',')]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = [int(x) for x in args.stride.split(',')]
    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {'input_dim': args.input_dim, 'feat_dim': args.feat_dim, 'kernel_size': kernel_size,
                    'filter': args.filter, 'inst_norm': args.inst_norm, 'input_norm': args.input_norm,
                    'stride': stride, 'fast': args.fast, 'avg_size': args.avg_size, 'time_dim': args.time_dim,
                    'padding': padding, 'encoder_type': args.encoder_type, 'vad': args.vad,
                    'transform': args.transform, 'embedding_size': args.embedding_size, 'ince': args.inception,
                    'resnet_size': args.resnet_size, 'num_classes_a': args.num_spks_a,
                    'num_classes_b': args.num_spks_b, 'input_len': args.input_len,
                    'channels': channels, 'alpha': args.alpha, 'dropout_p': args.dropout_p}

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save(model, check_path)

    # ``epoch`` and ``checkpoint`` get defaults so that running without a
    # usable ``args.resume`` no longer fails with an unbound local at the
    # final ``test`` call.
    epoch = start_epoch
    checkpoint = None
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            epoch = checkpoint['epoch']

            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    def _restore_criterion_state(criterion):
        """Best-effort copy of the stored second criterion's state into *criterion*."""
        if checkpoint is None:
            return
        try:
            criterion.load_state_dict(checkpoint['criterion'][1].state_dict())
        except Exception as err:  # deliberately tolerant: restoring is optional
            print('=> could not restore criterion state from checkpoint: {}'.format(err))

    # NOTE(review): the 'asoft' and 'center' branches size their layers from
    # ``train_dir_a`` / ``train_dir_b`` while the remaining branches read
    # ``args.num_spks``; confirm which source is intended.
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier_a = AngleLinear(in_features=args.embedding_size, out_features=train_dir_a.num_spks, m=args.m)
        model.classifier_b = AngleLinear(in_features=args.embedding_size, out_features=train_dir_b.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=int(train_dir_a.num_spks + train_dir_b.num_spks),
                                  feat_dim=args.embedding_size)
        _restore_criterion_state(xe_criterion)
    elif args.loss_type == 'gaussian':
        xe_criterion = GaussianLoss(num_classes=int(args.num_spks + args.num_spks),
                                    feat_dim=args.embedding_size)
    elif args.loss_type == 'coscenter':
        xe_criterion = CenterCosLoss(num_classes=int(args.num_spks + args.num_spks),
                                     feat_dim=args.embedding_size)
        _restore_criterion_state(xe_criterion)
    elif args.loss_type == 'mulcenter':
        xe_criterion = MultiCenterLoss(num_classes=int(args.num_spks + args.num_spks),
                                       feat_dim=args.embedding_size,
                                       num_center=args.num_center)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier_a = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=args.num_spks)
        model.classifier_b = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=args.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'arcsoft':
        ce_criterion = None
        model.classifier_a = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=args.num_spks)
        model.classifier_b = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=args.num_spks)
        xe_criterion = ArcSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'wasse':
        xe_criterion = Wasserstein_Loss(source_cls=args.source_cls)
    else:
        # Previously this fell through and failed later on an unbound local.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))

    # NOTE(review): these loaders are built but not passed to ``test`` below.
    enroll_loader_a = torch.utils.data.DataLoader(enroll_extract_dir, batch_size=args.batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_extract_dir, batch_size=args.batch_size, shuffle=False, **kwargs)

    if args.cuda:
        model = model.cuda()
        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()
        # Only report dropout when the model exposes it.
        if hasattr(model, 'dropout_p'):
            print('Dropout is {}.'.format(model.dropout_p))

    # The x-vector directory mirrors the checkpoint directory name.
    xvector_dir = args.check_path.replace('checkpoint', 'xvector')

    test(model, epoch, writer, xvector_dir)
def main():
    """Build the model, optionally restore it, then extract and verify.

    Model options are parsed from ``args`` and passed to ``create_model``.
    When ``args.resume`` names an existing file, its ``state_dict`` (minus
    ``num_batches_tracked`` entries) is merged into the model.  Loss
    objects, an optimizer, a scheduler and the train/valid/test loaders are
    constructed as in the training scripts.  With ``args.extract`` set,
    ``verification_extract`` is run over ``args.test_dir``; afterwards
    ``verification_test`` is run on the trials in ``args.trials`` and the
    module-level ``writer`` is closed.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the handled values.
    """
    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # Parse the comma-separated layer options.
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    padding = tuple(int((x - 1) / 2) for x in kernel_size)
    kernel_size = tuple(kernel_size)
    stride = [int(x) for x in args.stride.split(',')]
    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'filter': args.filter,
        'inst_norm': args.inst_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p,
    }

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path,
                                                   start_epoch)
        torch.save(model, check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    else:
        # Previously this fell through and failed later on an unbound local.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    # NOTE(review): the optimizer, scheduler and train/valid/test loaders
    # below are constructed but not used by the extraction / verification
    # steps at the end of this function (the training loop is disabled here).
    optimizer = create_optimizer(model.parameters(), args.optimizer,
                                 **opt_kwargs)
    if args.loss_type == 'center':
        # The center-loss parameters get a 5x learning rate.
        optimizer = torch.optim.SGD(
            [{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
             {'params': model.parameters()}],
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.momentum)
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            # The freshly attached classifier gets a 10x learning rate.
            classifier_params = set(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params,
                                 model.parameters())
            optimizer = torch.optim.SGD(
                [{'params': model.classifier.parameters(), 'lr': args.lr * 10},
                 {'params': rest_params}],
                lr=args.lr,
                weight_decay=args.weight_decay,
                momentum=args.momentum)
    if args.filter:
        # The filter layer gets a 0.05x learning rate.
        filter_params = set(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params,
                             model.parameters())
        optimizer = torch.optim.SGD(
            [{'params': model.filter_layer.parameters(), 'lr': args.lr * 0.05},
             {'params': rest_params}],
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.momentum)

    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    else:
        milestones = sorted(int(x) for x in args.milestones.split(','))
        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=int(args.batch_size / 2),
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    if args.cuda:
        model = model.cuda()
        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()
        # Only report dropout when the model exposes it (previously a bare
        # ``except: pass`` around the print).
        if hasattr(model, 'dropout_p'):
            print('Dropout is {}.'.format(model.dropout_p))

    # The x-vector directory mirrors the checkpoint directory name.
    xvector_dir = args.check_path.replace('checkpoint', 'xvector')

    if args.extract:
        # NOTE(review): the keyword is spelled ``filer_loader``; confirm it
        # matches the KaldiExtractDataset signature.
        extract_dir = KaldiExtractDataset(dir=args.test_dir,
                                          transform=transform_V,
                                          filer_loader=file_loader)
        extract_loader = torch.utils.data.DataLoader(extract_dir,
                                                     batch_size=1,
                                                     shuffle=False,
                                                     **kwargs)
        verification_extract(extract_loader, model, xvector_dir)

    verify_dir = ScriptVerifyDataset(dir=args.test_dir,
                                     trials_file=args.trials,
                                     xvectors_dir=xvector_dir,
                                     loader=read_vec_flt)
    verify_loader = torch.utils.data.DataLoader(verify_dir,
                                                batch_size=64,
                                                shuffle=False,
                                                **kwargs)
    verification_test(test_loader=verify_loader,
                      dist_type=('cos' if args.cos_sim else 'l2'),
                      log_interval=args.log_interval,
                      save=args.save_score)

    writer.close()
def main():
    """Train and evaluate a ``ResNet20`` speaker model.

    The model is created for ``train_dir.num_spks`` classes.  With
    ``args.save_init`` the initial weights are written to
    ``<args.check_path>/checkpoint_0.pth``; with ``args.resume`` pointing at
    an existing file the stored ``state_dict`` (minus ``num_batches_tracked``
    entries) is merged into the model and the stored epoch becomes the
    starting offset.  Loss objects are chosen by ``args.loss_type``, the
    optimizer by ``create_optimizer`` (or SGD with per-group learning rates
    for center loss / fine-tuning), and the learning rate follows a
    ``MultiStepLR`` schedule built from ``args.milestones``.  Each epoch
    calls ``train`` then ``test`` and steps the scheduler; the module-level
    ``writer`` is closed at the end.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the handled values.
    """
    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # Instantiate the model.
    model = ResNet20(embedding_size=args.embedding_size, num_classes=train_dir.num_spks, dropout_p=args.dropout_p)

    start_epoch = 0
    if args.save_init:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save({'epoch': 0,
                    'state_dict': model.state_dict()},
                   check_path)

    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size, out_features=train_dir.num_spks, m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size, n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    else:
        # Previously this fell through and failed later on an unbound local.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    if args.loss_type == 'center':
        # The center-loss parameters get a 5x learning rate.
        optimizer = torch.optim.SGD([{'params': xe_criterion.parameters(), 'lr': args.lr * 5},
                                     {'params': model.parameters()}],
                                    lr=args.lr, weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            # The freshly attached classifier gets a 5x learning rate.
            classifier_params = set(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params, model.parameters())
            optimizer = torch.optim.SGD([{'params': model.classifier.parameters(), 'lr': args.lr * 5},
                                         {'params': rest_params}],
                                        lr=args.lr, weight_decay=args.weight_decay,
                                        momentum=args.momentum)

    milestones = sorted(int(x) for x in args.milestones.split(','))
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    if args.cuda:
        model = model.cuda()
        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()

    for epoch in range(start, end):
        # Show the learning rate of every parameter group for this epoch.
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer), end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, ce, optimizer, scheduler, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()

    writer.close()
def main():
    """Train and evaluate a speaker classification/embedding model.

    Steps, in order:
      1. build the model with ``create_model`` from the parsed ``args``;
      2. optionally restore model weights from ``args.resume``;
      3. select the loss criteria (and classifier head) for ``args.loss_type``;
      4. build the optimizer and a ``MultiStepLR`` scheduler;
      5. optionally save an initial checkpoint;
      6. run ``train``/``test`` once per epoch, stepping the scheduler.

    Relies on module-level names defined elsewhere in the file: ``args``,
    ``train_dir``, ``valid_dir``, ``test_part``, ``kwargs``, ``opt_kwargs``,
    ``writer``, ``train`` and ``test``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of
            ``'soft'``, ``'asoft'``, ``'center'`` or ``'amsoft'``.
    """
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # instantiate model and initialize weights
    model_kwargs = {
        'input_dim': args.feat_dim,
        'embedding_size': args.embedding_size,
        'num_classes': len(train_dir.speakers),
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))

    model = create_model(args.model, **model_kwargs)

    if args.cuda:
        model.cuda()

    start = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            # Load the checkpoint once; it provides both the epoch and weights.
            checkpoint = torch.load(args.resume)
            start = checkpoint['epoch']
            # Drop the BatchNorm 'num_batches_tracked' entries before loading.
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model.load_state_dict(filtered)
            # NOTE: optimizer, scheduler and criterion state are not restored.
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # ce_criterion is the plain cross-entropy loss; xe_criterion is the
    # loss-type specific criterion (None when only cross-entropy is used).
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    else:
        # Fail early instead of hitting a NameError on xe_criterion later on.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    optimizer = create_optimizer(model.parameters(), args.optimizer,
                                 **opt_kwargs)
    if args.loss_type == 'center':
        # The center-loss parameters get their own group at 5x the base lr.
        optimizer = torch.optim.SGD([{
            'params': xe_criterion.parameters(),
            'lr': args.lr * 5
        }, {
            'params': model.parameters()
        }],
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    if args.finetune and args.loss_type in ('asoft', 'amsoft'):
        # When fine-tuning, the replaced classifier head trains at 5x the
        # base lr while every other parameter keeps the base lr.
        classifier_params = list(map(id, model.classifier.parameters()))
        rest_params = filter(lambda p: id(p) not in classifier_params,
                             model.parameters())
        optimizer = torch.optim.SGD(
            [{
                'params': model.classifier.parameters(),
                'lr': args.lr * 5
            }, {
                'params': rest_params
            }],
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.momentum)

    # args.milestones is a comma-separated list of epoch numbers.
    milestones = sorted(int(x) for x in args.milestones.split(','))
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start)
        torch.save(
            {
                'epoch': start,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, check_path)

    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    # NOTE(review): the end epoch is absolute (args.epochs + 1), not relative
    # to `start` - confirm this is intended when resuming.
    end = args.epochs + 1

    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    ce = [ce_criterion, xe_criterion]
    if args.cuda:
        # The classifier head may have been replaced after the first .cuda()
        # call above, so move the whole model (and the criteria) again.
        model = model.cuda()
        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()

    for epoch in range(start, end):
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(
            args.optimizer),
              end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, optimizer, ce, epoch)
        test(test_loader, valid_loader, model, epoch)

        scheduler.step()

    writer.close()
Exemple #8
0
    def __init__(self,
                 embedding_size,
                 layers=[1, 1, 1, 0],
                 block=BasicBlock,
                 n_classes=1000,
                 m=3):
        """Assemble three conv/residual stages, the embedding head and the angular classifier.

        Args:
            embedding_size: output size of the ``fc`` embedding layer.
            layers: per-stage block counts handed to ``_make_layer``; only
                the first three entries are read here.
            block: residual block class handed to ``_make_layer``.
            n_classes: number of outputs of the ``AngleLinear`` classifier.
            m: margin value forwarded to ``AngleLinear``.
        """
        super(SuperficialResCNN, self).__init__()

        def strided_conv(in_channels, out_channels):
            # Every stage opens with the same 5x5, stride-2, bias-free conv.
            return nn.Conv2d(in_channels, out_channels,
                             kernel_size=5, stride=2, padding=2, bias=False)

        self.embedding_size = embedding_size
        self.relu = ReLU(inplace=True)

        # Stage 1: 1 -> 64 channels
        self.in_planes = 64
        self.conv1 = strided_conv(1, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, layers[0])

        # Stage 2: 64 -> 128 channels
        self.in_planes = 128
        self.conv2 = strided_conv(64, 128)
        self.bn2 = nn.BatchNorm2d(128)
        self.layer2 = self._make_layer(block, 128, layers[1])

        # Stage 3: 128 -> 256 channels
        self.in_planes = 256
        self.conv3 = strided_conv(128, 256)
        self.bn3 = nn.BatchNorm2d(256)
        self.layer3 = self._make_layer(block, 256, layers[2])

        # A fourth (512-channel) stage is not built, so layers[3] is unused.

        self.avg_pool = nn.AdaptiveAvgPool2d((4, 1))

        # The linear layer expects in_planes * 4 inputs - presumably the
        # flattened (4, 1) pooled map; confirm against forward().
        embedding_head = [
            nn.Linear(self.in_planes * 4, embedding_size),
            nn.BatchNorm1d(embedding_size),
        ]
        self.fc = nn.Sequential(*embedding_head)

        self.angle_linear = AngleLinear(in_features=embedding_size,
                                        out_features=n_classes,
                                        m=m)

        # Initialise the parameters of every sub-module.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Zero-mean normal with std sqrt(2 / n), where
                # n = kernel_height * kernel_width * out_channels.
                k_h, k_w = module.kernel_size[0], module.kernel_size[1]
                fan_out = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # Batch-norm layers start with weight 1 and bias 0.
                module.weight.data.fill_(1)
                module.bias.data.zero_()
Exemple #9
0
        'num_classes_a': train_dir_a.num_spks,
        'num_classes_b': train_dir_b.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))
    dist_type = 'cos' if args.cos_sim else 'l2'
    print('Testing with %s distance, ' % dist_type)

    if args.valid or args.extract:
        model = create_model(args.model, **model_kwargs)
        if args.loss_type == 'asoft':
            model.classifier_a = AngleLinear(in_features=args.embedding_size,
                                             out_features=train_dir_a.num_spks,
                                             m=args.m)
            model.classifier_b = AngleLinear(in_features=args.embedding_size,
                                             out_features=train_dir_b.num_spks,
                                             m=args.m)
        elif args.loss_type in ['amsoft', 'arcsoft']:
            model.classifier_a = AdditiveMarginLinear(
                feat_dim=args.embedding_size, n_classes=train_dir_a.num_spks)
            model.classifier_b = AdditiveMarginLinear(
                feat_dim=args.embedding_size, n_classes=train_dir_b.num_spks)

        assert os.path.isfile(args.resume)
        print('=> loading checkpoint {}'.format(args.resume))
        checkpoint = torch.load(args.resume)
        # start_epoch = checkpoint['epoch']
def main():
    """Validate and/or extract x-vectors with a saved model, then score the trials.

    When ``args.valid`` or ``args.extract`` is set, the model is rebuilt with
    ``create_model``, its weights are loaded from ``args.resume`` and it is
    passed to ``valid`` and/or ``extract``. Regardless of those flags, the
    vectors under ``args.xvector_dir`` are then read through
    ``ScriptVerifyDataset`` and handed to ``test``.

    Relies on module-level names defined elsewhere in the file: ``args``,
    ``train_dir``, ``valid_dir``, ``verfify_dir``, ``kwargs``, ``valid``,
    ``extract`` and ``test``.

    Raises:
        FileNotFoundError: if a model is required but ``args.resume`` is not
            an existing file.
    """
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))

    # args.kernel_size is a comma-separated list; the padding is derived from
    # it as int((k - 1) / 2) per dimension.
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    padding = tuple(int((x - 1) / 2) for x in kernel_size)
    kernel_size = tuple(kernel_size)

    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {
        'embedding_size': args.embedding_size,
        'resnet_size': args.resnet_size,
        'input_dim': args.feat_dim,
        'num_classes': train_dir.num_spks,
        'alpha': args.alpha,
        'channels': channels,
        'stride': args.stride,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'kernel_size': kernel_size,
        'padding': padding,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))
    if args.valid or args.extract:
        model = create_model(args.model, **model_kwargs)
        # The margin-based losses use their own classifier head, which must
        # be in place before the checkpoint weights are loaded.
        if args.loss_type == 'asoft':
            model.classifier = AngleLinear(in_features=args.embedding_size,
                                           out_features=train_dir.num_spks,
                                           m=args.m)
        elif args.loss_type == 'amsoft':
            model.classifier = AdditiveMarginLinear(
                feat_dim=args.embedding_size, n_classes=train_dir.num_spks)

        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                'No checkpoint found at {}'.format(args.resume))
        print('=> loading checkpoint {}'.format(args.resume))
        checkpoint = torch.load(args.resume)

        # Drop the BatchNorm 'num_batches_tracked' entries before loading.
        filtered = {
            k: v
            for k, v in checkpoint['state_dict'].items()
            if 'num_batches_tracked' not in k
        }
        model.load_state_dict(filtered)

        # Best effort: not every model has a `dropout` sub-module.
        try:
            model.dropout.p = args.dropout_p
        except AttributeError:
            pass
        start = args.start_epoch
        print('Epoch is : ' + str(start))

        if args.cuda:
            model.cuda()

        if args.valid:
            valid_loader = torch.utils.data.DataLoader(
                valid_dir,
                batch_size=args.test_batch_size,
                shuffle=False,
                **kwargs)
            valid(valid_loader, model)

        if args.extract:
            verify_loader = torch.utils.data.DataLoader(
                verfify_dir,
                batch_size=args.test_batch_size,
                shuffle=False,
                **kwargs)
            extract(verify_loader, model, args.xvector_dir)

    # Scoring stage: runs whether or not a model was loaded above.
    file_loader = read_vec_flt
    test_dir = ScriptVerifyDataset(dir=args.test_dir,
                                   trials_file=args.trials,
                                   xvectors_dir=args.xvector_dir,
                                   loader=file_loader)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.test_batch_size *
                                              64,
                                              shuffle=False,
                                              **kwargs)
    test(test_loader)
Exemple #11
0
def main():
    """Jointly train one model on two speaker sets (A and B) and evaluate it.

    Steps, in order:
      1. build the model with ``create_model`` from the parsed ``args``;
      2. optionally save the initial model and/or restore weights from
         ``args.resume``;
      3. select the loss criteria (and the two classifier heads) for
         ``args.loss_type``;
      4. build the optimizer and learning-rate scheduler and dump the
         configuration to a ``model.<date>.cfg`` file;
      5. build the data loaders, splitting ``args.batch_size`` between the
         two training sets in proportion to their lengths;
      6. per epoch: train, validate, periodically checkpoint / extract /
         test, and step the scheduler.

    Relies on module-level names defined elsewhere in the file: ``args``,
    ``train_dir_a``/``train_dir_b``, ``valid_dir_a``/``valid_dir_b``,
    ``train_extract_dir``, ``kwargs``, ``opt_kwargs``, ``writer``, ``train``,
    ``valid_class``, ``valid_test`` and ``test``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the handled types.
    """

    def _restore_criterion(saved_checkpoint, criterion_module):
        """Best-effort reload of the auxiliary loss state from a checkpoint."""
        if saved_checkpoint is None:
            return
        try:
            saved = saved_checkpoint['criterion']
            criterion_module.load_state_dict(saved[1].state_dict())
        except (KeyError, IndexError, TypeError, AttributeError,
                RuntimeError) as err:
            # Keep going with a freshly initialised criterion, but say so.
            print('=> criterion state not restored: {}'.format(err))

    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    options = ["\'%s\': \'%s\'" % (str(k), str(opts[k])) for k in sorted(opts)]

    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers for set A: {}.'.format(train_dir_a.num_spks))
    print('Number of Speakers for set B: {}.\n'.format(train_dir_b.num_spks))

    # instantiate model and initialize weights
    # kernel_size / padding / stride / channels are comma-separated options.
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    if args.padding == '':
        # No explicit padding given: derive it from the kernel size.
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = [int(x) for x in args.padding.split(',')]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = [int(x) for x in args.stride.split(',')]
    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes_a': train_dir_a.num_spks,
        'num_classes_b': train_dir_b.num_spks,
        'input_len': args.input_len,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path,
                                                   start_epoch)
        torch.save(model, check_path)

    # Stays None when nothing was resumed, so later code can test for it.
    checkpoint = None
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            checkpoint_state_dict = checkpoint['state_dict']
            # Older checkpoints stored the state dict wrapped in a 1-tuple.
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]
            filtered = {
                k: v
                for k, v in checkpoint_state_dict.items()
                if 'num_batches_tracked' not in k
            }
            # Overlay the saved weights on the current state so that keys
            # missing from the checkpoint keep their initial values.
            model_dict = model.state_dict()
            model_dict.update(filtered)
            model.load_state_dict(model_dict)
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # ce_criterion is the plain cross-entropy loss; xe_criterion is the
    # loss-type specific criterion (None when only cross-entropy is used).
    num_spks_total = int(train_dir_a.num_spks + train_dir_b.num_spks)
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier_a = AngleLinear(in_features=args.embedding_size,
                                         out_features=train_dir_a.num_spks,
                                         m=args.m)
        model.classifier_b = AngleLinear(in_features=args.embedding_size,
                                         out_features=train_dir_b.num_spks,
                                         m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)

    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=num_spks_total,
                                  feat_dim=args.embedding_size)
        if args.resume:
            _restore_criterion(checkpoint, xe_criterion)

    elif args.loss_type == 'gaussian':
        xe_criterion = GaussianLoss(num_classes=num_spks_total,
                                    feat_dim=args.embedding_size)
    elif args.loss_type == 'coscenter':
        xe_criterion = CenterCosLoss(num_classes=num_spks_total,
                                     feat_dim=args.embedding_size)
        if args.resume:
            _restore_criterion(checkpoint, xe_criterion)

    elif args.loss_type == 'mulcenter':
        xe_criterion = MultiCenterLoss(num_classes=num_spks_total,
                                       feat_dim=args.embedding_size,
                                       num_center=args.num_center)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier_a = AdditiveMarginLinear(
            feat_dim=args.embedding_size, n_classes=train_dir_a.num_spks)
        model.classifier_b = AdditiveMarginLinear(
            feat_dim=args.embedding_size, n_classes=train_dir_b.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)

    elif args.loss_type == 'arcsoft':
        ce_criterion = None
        model.classifier_a = AdditiveMarginLinear(
            feat_dim=args.embedding_size, n_classes=train_dir_a.num_spks)
        model.classifier_b = AdditiveMarginLinear(
            feat_dim=args.embedding_size, n_classes=train_dir_b.num_spks)
        xe_criterion = ArcSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'wasse':
        xe_criterion = Wasserstein_Loss(source_cls=args.source_cls)
    else:
        # Fail early instead of hitting a NameError on xe_criterion later on.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    optimizer = create_optimizer(model.parameters(), args.optimizer,
                                 **opt_kwargs)
    if args.loss_type in ['center', 'mulcenter', 'gaussian', 'coscenter']:
        # These criteria own parameters, trained at 5x the base lr.
        optimizer = torch.optim.SGD([{
            'params': xe_criterion.parameters(),
            'lr': args.lr * 5
        }, {
            'params': model.parameters()
        }],
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    if args.filter == 'fDLR':
        # The filter layer trains at 0.05x the base lr.
        # NOTE(review): this replaces any optimizer built above, so for the
        # center-type losses the criterion parameters are no longer in the
        # optimizer - confirm this is intended.
        filter_params = list(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params,
                             model.parameters())
        optimizer = torch.optim.SGD([{
            'params': model.filter_layer.parameters(),
            'lr': args.lr * 0.05
        }, {
            'params': rest_params
        }],
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)

    # Save model config txt
    with open(
            osp.join(
                args.check_path,
                'model.%s.cfg' % time.strftime("%Y.%m.%d", time.localtime())),
            'w') as f:
        f.write('model: ' + str(model) + '\n')
        f.write('CrossEntropy: ' + str(ce_criterion) + '\n')
        f.write('Other Loss: ' + str(xe_criterion) + '\n')
        f.write('Optimizer: ' + str(optimizer) + '\n')

    # args.milestones is a comma-separated list of epoch numbers; it is also
    # used below to decide when to checkpoint and test.
    milestones = sorted(int(x) for x in args.milestones.split(','))
    if args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer,
                                               gamma=args.gamma,
                                               verbose=True)
    elif args.scheduler == 'rop':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   patience=args.patience,
                                                   min_lr=1e-5,
                                                   verbose=True)
    else:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=milestones,
                                             gamma=0.1,
                                             verbose=True)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    # Split the training batch between A and B in proportion to set length.
    batch_size_a = int(args.batch_size * len(train_dir_a) /
                       (len(train_dir_a) + len(train_dir_b)))
    train_loader_a = torch.utils.data.DataLoader(train_dir_a,
                                                 batch_size=batch_size_a,
                                                 shuffle=False,
                                                 **kwargs)

    batch_size_b = args.batch_size - batch_size_a
    train_loader_b = torch.utils.data.DataLoader(train_dir_b,
                                                 batch_size=batch_size_b,
                                                 shuffle=False,
                                                 **kwargs)
    train_loader = [train_loader_a, train_loader_b]

    train_extract_loader = torch.utils.data.DataLoader(train_extract_dir,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       **kwargs)

    print('Batch_size is {} for A, and {} for B.'.format(
        batch_size_a, batch_size_b))

    # Validation batches: A uses an eighth of the training batch size and B
    # is scaled by the ratio of the two validation set lengths.
    batch_size_a = int(args.batch_size / 8)
    valid_loader_a = torch.utils.data.DataLoader(valid_dir_a,
                                                 batch_size=batch_size_a,
                                                 shuffle=False,
                                                 **kwargs)

    batch_size_b = int(len(valid_dir_b) / len(valid_dir_a) * batch_size_a)
    valid_loader_b = torch.utils.data.DataLoader(valid_dir_b,
                                                 batch_size=batch_size_b,
                                                 shuffle=False,
                                                 **kwargs)
    valid_loader = valid_loader_a, valid_loader_b

    if args.cuda:
        if len(args.gpu_id) > 1:
            print("Continue with gpu: %s ..." % str(args.gpu_id))
            # NOTE(review): the rendezvous file path is hard-coded and the
            # group is created with world_size=1 / rank=0.
            torch.distributed.init_process_group(
                backend="nccl",
                # init_method='tcp://localhost:23456',
                init_method=
                'file:///home/work2020/yangwenhao/project/lstm_speaker_verification/data/sharedfile',
                rank=0,
                world_size=1)
            model = model.cuda()
            model = DistributedDataParallel(model, find_unused_parameters=True)

        else:
            model = model.cuda()

        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()
        # Best effort: not every model exposes `dropout_p`.
        try:
            print('Dropout is {}.'.format(model.dropout_p))
        except AttributeError:
            pass

    # The x-vector directory mirrors the checkpoint directory, with
    # 'checkpoint' replaced by 'xvector' in the path.
    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')

    start_time = time.time()
    for epoch in range(start, end):
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(
            args.optimizer),
              end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, ce, optimizer, epoch)
        valid_loss = valid_class(valid_loader, model, ce, epoch)

        if epoch % 4 == 1 or epoch == (end - 1) or epoch in milestones:
            check_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)
            # Save the unwrapped module's weights so the keys carry no
            # 'module.' prefix. (A stray trailing comma used to wrap this
            # dict in a 1-tuple; the resume code above still accepts both.)
            if isinstance(model, DistributedDataParallel):
                model_state_dict = model.module.state_dict()
            else:
                model_state_dict = model.state_dict()

            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model_state_dict,
                    'criterion': ce
                }, check_path)

        if epoch % 2 == 1 or epoch == (end - 1):
            valid_test(train_extract_loader, model, epoch, xvector_dir)

        if epoch != (end - 2) and (epoch % 4 == 1 or epoch in milestones
                                   or epoch == (end - 1)):
            test(model, epoch, writer, xvector_dir)

        # ReduceLROnPlateau needs the monitored metric; the others do not.
        if args.scheduler == 'rop':
            scheduler.step(valid_loss)
        else:
            scheduler.step()

    writer.close()
    stop_time = time.time()
    # Elapsed wall-clock time (stop - start), averaged over the epochs run;
    # max(..., 1) avoids a division by zero when no epoch was run.
    t = float(stop_time - start_time)
    print("Running %.4f minutes for each epoch.\n" %
          (t / 60 / max(end - start, 1)))
def _load_checkpoint_weights(model, checkpoint_path):
    """Load the weights stored in a training checkpoint into ``model``.

    The checkpoint must be a mapping with an ``'epoch'`` entry and a
    ``'state_dict'`` entry. Checkpoints written by older versions of this
    script stored the state dict wrapped in a 1-tuple (a stray trailing
    comma); both layouts are accepted.

    Args:
        model: module providing ``state_dict()`` and ``load_state_dict()``.
        checkpoint_path: path passed to ``torch.load``.

    Returns:
        The value stored under ``checkpoint['epoch']``.

    Raises:
        ValueError: if the checkpoint holds no weights after filtering.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path)
    epoch = checkpoint['epoch']

    state_dict = checkpoint['state_dict']
    if isinstance(state_dict, tuple):
        # Legacy layout: the state dict was saved as a 1-tuple.
        state_dict = state_dict[0]

    weights = {
        k: v
        for k, v in state_dict.items() if 'num_batches_tracked' not in k
    }
    if not weights:
        raise ValueError(
            'checkpoint {} contains no model weights'.format(checkpoint_path))

    prefix = 'module.'
    if next(iter(weights)).startswith(prefix):
        # Keys carry the `module.` prefix (as produced by a wrapped model's
        # state dict): strip it and load the result directly.
        stripped = OrderedDict(
            (k[len(prefix):], v) for k, v in weights.items())
        model.load_state_dict(stripped)
    else:
        # Overlay the checkpoint on the model's current state so that keys
        # missing from the checkpoint keep their current values.
        merged = model.state_dict()
        merged.update(weights)
        model.load_state_dict(merged)

    return epoch


def main():
    """Build the model, losses, optimizer and loaders, then run training.

    Everything is configured from the module-level ``args`` namespace and
    the module-level datasets (``train_dir``, ``valid_dir``,
    ``train_extract_dir``). Per epoch the function calls ``train`` and
    ``valid_class``; on selected epochs it writes
    ``<args.check_path>/checkpoint_<epoch>.pth`` and calls ``valid_test`` and
    ``test``. A ``KeyboardInterrupt`` stops training early. The function
    always finishes by calling ``exit(0)``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the supported losses.
    """
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    options = [
        "\'%s\': \'%s\'" % (str(k), str(opts[k])) for k in sorted(opts)
    ]

    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # Comma-separated command-line options are parsed into integer sequences.
    kernel_size = [int(x) for x in args.kernel_size.split(',')]
    context = [int(x) for x in args.context.split(',')]
    if args.padding == '':
        # No explicit padding given: derive it from the kernel size.
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = [int(x) for x in args.padding.split(',')]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = [int(x) for x in args.stride.split(',')]
    channels = [int(x) for x in args.channels.split(',')]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'context': context,
        'filter_fix': args.filter_fix,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'exp': args.exp,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))
    dist_type = 'cos' if args.cos_sim else 'l2'
    print('Testing with %s distance, ' % dist_type)

    # instantiate model and initialize weights
    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        # Note: this stores the whole model object, not a
        # {'epoch', 'state_dict', 'criterion'} checkpoint dict.
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path,
                                                   start_epoch)
        torch.save(model, check_path)

    iteration = 0  # if args.resume else 0

    # Finetuning: load the weights *before* the classifier is (possibly)
    # replaced below, so a new classifier starts from fresh parameters.
    if args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            start_epoch = _load_checkpoint_weights(model, args.resume)
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # ce_criterion is the plain cross-entropy term; xe_criterion is the
    # loss-type-specific term. Margin-based losses replace the classifier
    # layer and drop the plain cross-entropy.
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'gaussian':
        xe_criterion = GaussianLoss(num_classes=train_dir.num_spks,
                                    feat_dim=args.embedding_size)
    elif args.loss_type == 'coscenter':
        xe_criterion = CenterCosLoss(num_classes=train_dir.num_spks,
                                     feat_dim=args.embedding_size)
    elif args.loss_type == 'mulcenter':
        xe_criterion = MultiCenterLoss(num_classes=train_dir.num_spks,
                                       feat_dim=args.embedding_size,
                                       num_center=args.num_center)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'arcsoft':
        ce_criterion = None
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = ArcSoftmaxLoss(margin=args.margin,
                                      s=args.s,
                                      iteraion=iteration,
                                      all_iteraion=args.all_iteraion)
    elif args.loss_type == 'wasse':
        xe_criterion = Wasserstein_Loss(source_cls=args.source_cls)
    elif args.loss_type == 'ring':
        xe_criterion = RingLoss(ring=args.ring)
        args.alpha = 0.0
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on xe_criterion; fail early and clearly instead.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    # Parameter groups handed to the optimizer. Some components get their own
    # learning rate of args.lr * args.lr_ratio.
    model_para = model.parameters()
    if args.loss_type in [
            'center', 'mulcenter', 'gaussian', 'coscenter', 'ring'
    ]:
        # These losses own parameters of their own that must be optimized too.
        assert args.lr_ratio > 0
        model_para = [{
            'params': xe_criterion.parameters(),
            'lr': args.lr * args.lr_ratio
        }, {
            'params': model.parameters()
        }]
    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            classifier_params = list(map(id, model.classifier.parameters()))
            rest_params = filter(lambda p: id(p) not in classifier_params,
                                 model.parameters())
            assert args.lr_ratio > 0
            model_para = [{
                'params': model.classifier.parameters(),
                'lr': args.lr * args.lr_ratio
            }, {
                'params': rest_params
            }]

    # NOTE(review): this replaces any parameter groups built above, so for the
    # center-style losses the loss parameters are no longer passed to the
    # optimizer when a filter layer is selected -- confirm this is intended.
    if args.filter in ['fDLR', 'fBLayer', 'fLLayer', 'fBPLayer']:
        filter_params = list(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params,
                             model.parameters())
        model_para = [{
            'params': model.filter_layer.parameters(),
            'lr': args.lr * args.lr_ratio
        }, {
            'params': rest_params
        }]

    optimizer = create_optimizer(model_para, args.optimizer, **opt_kwargs)

    # Plain resume: the checkpoint is loaded after the classifier has been
    # set up, so the classifier weights are restored as well.
    if not args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            start_epoch = _load_checkpoint_weights(model, args.resume)
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Save model config txt
    conf_name = 'model.%s.conf' % time.strftime("%Y.%m.%d", time.localtime())
    with open(osp.join(args.check_path, conf_name), 'w') as f:
        f.write('model: ' + str(model) + '\n')
        f.write('CrossEntropy: ' + str(ce_criterion) + '\n')
        f.write('Other Loss: ' + str(xe_criterion) + '\n')
        f.write('Optimizer: ' + str(optimizer) + '\n')

    milestones = sorted(int(x) for x in args.milestones.split(','))
    if args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == 'rop':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   patience=args.patience,
                                                   min_lr=1e-5)
    else:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=milestones,
                                             gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(
        train_dir,
        batch_size=args.batch_size,
        collate_fn=PadCollate(dim=2,
                              num_batch=int(
                                  np.ceil(len(train_dir) / args.batch_size)),
                              min_chunk_size=args.min_chunk_size,
                              max_chunk_size=args.max_chunk_size),
        shuffle=args.shuffle,
        **kwargs)

    valid_loader = torch.utils.data.DataLoader(
        valid_dir,
        batch_size=int(args.batch_size / 2),
        collate_fn=PadCollate(dim=2,
                              fix_len=True,
                              min_chunk_size=args.chunk_size,
                              max_chunk_size=args.chunk_size + 1),
        shuffle=False,
        **kwargs)
    train_extract_loader = torch.utils.data.DataLoader(train_extract_dir,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       **extract_kwargs)

    if args.cuda:
        if len(args.gpu_id) > 1:
            print("Continue with gpu: %s ..." % str(args.gpu_id))
            # NOTE(review): the rendezvous file path, rank and world size are
            # hard-coded for a single-process setup on one specific machine.
            torch.distributed.init_process_group(
                backend="nccl",
                # init_method='tcp://localhost:23456',
                init_method=
                'file:///home/ssd2020/yangwenhao/lstm_speaker_verification/data/sharedfile',
                rank=0,
                world_size=1)
            model = DistributedDataParallel(model.cuda(),
                                            find_unused_parameters=True)

        else:
            model = model.cuda()

        for i, criterion in enumerate(ce):
            if criterion is not None:
                ce[i] = criterion.cuda()
        try:
            print('Dropout is {}.'.format(model.dropout_p))
        except AttributeError:
            # Not every model (or wrapper) exposes dropout_p; this print is
            # purely informational.
            pass

    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')

    start_time = time.time()
    # Pre-bind so the KeyboardInterrupt handler can always read it.
    epoch = start
    try:
        for epoch in range(start, end):
            lr_string = '\n\33[1;34m Current \'{}\' learning rate is '.format(
                args.optimizer)
            for param_group in optimizer.param_groups:
                lr_string += '{:.6f} '.format(param_group['lr'])
            print('%s \33[0m' % lr_string)

            train(train_loader, model, ce, optimizer, epoch)
            valid_loss = valid_class(valid_loader, model, ce, epoch)

            # Checkpoint + evaluation on every 4th epoch (1, 5, 9, ...), on
            # LR milestones and on the final epoch.
            if (epoch == 1 or epoch !=
                (end - 2)) and (epoch % 4 == 1 or epoch in milestones
                                or epoch == (end - 1)):
                model.eval()
                check_path = '{}/checkpoint_{}.pth'.format(
                    args.check_path, epoch)
                # Save the unwrapped module's weights as a plain dict (the
                # old code accidentally wrapped it in a 1-tuple).
                unwrapped = model.module if isinstance(
                    model, DistributedDataParallel) else model
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': unwrapped.state_dict(),
                        'criterion': ce
                    }, check_path)

                valid_test(train_extract_loader, model, epoch, xvector_dir)
                test(model, epoch, writer, xvector_dir)
                if epoch != (end - 1):
                    # Only the final epoch's xvector directories are kept.
                    # Remove each directory independently so that one
                    # failure does not skip the other.
                    for subset in ('train', 'test'):
                        try:
                            shutil.rmtree("%s/%s/epoch_%s" %
                                          (xvector_dir, subset, epoch))
                        except OSError as e:
                            print('rm dir xvectors error:', e)

            if args.scheduler == 'rop':
                scheduler.step(valid_loss)
            else:
                scheduler.step()

    except KeyboardInterrupt:
        # Report timing over the epochs that were started before this one.
        end = epoch
    finally:
        # Close the summary writer even when training aborts with an error.
        writer.close()

    stop_time = time.time()
    t = float(stop_time - start_time)
    print("Running %.4f minutes for each epoch.\n" % (t / 60 /
                                                      (max(end - start, 1))))
    exit(0)
def main():
    """Run extraction over a range of saved checkpoints.

    Builds the model from the module-level ``args`` namespace, wraps the
    module-level datasets (``train_part``, ``veri_dir``, ``valid_part``,
    ``test_dir``) in non-shuffling ``DataLoader`` objects and then, for every
    epoch ``e`` from ``args.start_epochs`` to ``args.epochs`` (inclusive):

    * loads ``<args.check_path>/checkpoint_<e>.pth`` into the model, skipping
      the epoch when that file does not exist;
    * creates ``<args.extract_path>/epoch_<e>`` if needed;
    * calls ``train_extract`` / ``test_extract`` (defined elsewhere) with the
      loader, the model, that directory and a set-name tag. The train, valid
      and veri sets are skipped when ``args.test_only`` is set.
    """
    print('\nNumber of Speakers: {}.'.format(train_dir.num_spks))
    # print the experiment configuration
    print('Current time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))

    # instantiate model and initialize weights
    # Comma-separated command-line options are parsed into integer sequences.
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
    if args.padding == '':
        # No explicit padding given: derive it from the kernel size.
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = args.padding.split(',')
        padding = [int(x) for x in padding]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = args.stride.split(',')
    stride = [int(x) for x in stride]

    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_dir.num_spks,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))

    model = create_model(args.model, **model_kwargs)
    # Margin-based losses use a replacement classifier layer; install it
    # before loading so the checkpoint's classifier weights fit the model.
    if args.loss_type == 'asoft':
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
    elif args.loss_type == 'amsoft' or args.loss_type == 'arcsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)

    # All loaders keep the dataset order (shuffle=False).
    train_loader = DataLoader(train_part,
                              batch_size=args.batch_size,
                              shuffle=False,
                              **kwargs)
    veri_loader = DataLoader(veri_dir,
                             batch_size=args.batch_size,
                             shuffle=False,
                             **kwargs)
    valid_loader = DataLoader(valid_part,
                              batch_size=args.batch_size,
                              shuffle=False,
                              **kwargs)
    test_loader = DataLoader(test_dir,
                             batch_size=args.batch_size,
                             shuffle=False,
                             **kwargs)
    # sitw_test_loader = DataLoader(sitw_test_part, batch_size=args.batch_size, shuffle=False, **kwargs)
    # sitw_dev_loader = DataLoader(sitw_dev_part, batch_size=args.batch_size, shuffle=False, **kwargs)

    # Template for per-epoch checkpoint paths; filled in with the epoch number.
    resume_path = args.check_path + '/checkpoint_{}.pth'
    print('=> Saving output in {}\n'.format(args.extract_path))
    # Inclusive range of epochs to process.
    epochs = np.arange(args.start_epochs, args.epochs + 1)

    for e in epochs:
        # Load model from Checkpoint file
        if os.path.isfile(resume_path.format(e)):
            print('=> loading checkpoint {}'.format(resume_path.format(e)))
            checkpoint = torch.load(resume_path.format(e))
            checkpoint_state_dict = checkpoint['state_dict']
            # Some checkpoints store the state dict wrapped in a tuple;
            # unwrap it in that case.
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]

            # epoch = checkpoint['epoch']
            # if e == 0:
            #     filtered = checkpoint.state_dict()
            # else:
            # Drop the 'num_batches_tracked' entries before loading.
            filtered = {
                k: v
                for k, v in checkpoint_state_dict.items()
                if 'num_batches_tracked' not in k
            }
            # The first key decides whether all keys carry a `module` prefix.
            if list(filtered.keys())[0].startswith('module'):
                new_state_dict = OrderedDict()
                for k, v in filtered.items():
                    name = k[
                        7:]  # remove `module.`: keep the key from index 7 onward, dropping the prefix
                    new_state_dict[name] = v  # the new dict maps each stripped key to its original value

                model.load_state_dict(new_state_dict)
            else:
                # Overlay the checkpoint on the model's current state so keys
                # absent from the checkpoint keep their current values.
                model_dict = model.state_dict()
                model_dict.update(filtered)
                model.load_state_dict(model_dict)

        else:
            print('=> no checkpoint found at %s' % resume_path.format(e))
            continue
        # NOTE(review): the model is moved to CUDA unconditionally here,
        # without consulting any CPU/GPU option -- confirm this is intended.
        model.cuda()

        # One output directory per processed epoch.
        file_dir = args.extract_path + '/epoch_%d' % e
        if not os.path.exists(file_dir):
            os.makedirs(file_dir)

        if not args.test_only:
            # if args.cuda:
            #     model_conv1 = model.conv1.weight.cpu().detach().numpy()
            #     np.save(file_dir + '/model.conv1.npy', model_conv1)

            train_extract(train_loader, model, file_dir,
                          '%s_train' % args.train_set_name)
            train_extract(valid_loader, model, file_dir,
                          '%s_valid' % args.train_set_name)
            test_extract(veri_loader, model, file_dir,
                         '%s_veri' % args.train_set_name)

        # The test set is always processed, even with args.test_only.
        test_extract(test_loader, model, file_dir,
                     '%s_test' % args.test_set_name)