Example #1
def getCriterion(args, downsampling, nSpeakers, nPhones):
    dimFeatures = args.hiddenGar if not args.onEncoder else args.hiddenEncoder
    if not args.supervised:
        if args.cpc_mode == 'none':
            cpcCriterion = cr.NoneCriterion()
        else:
            sizeInputSeq = (args.sizeWindow // downsampling)
            cpcCriterion = cr.CPCUnsupersivedCriterion(
                args.nPredicts,
                args.hiddenGar,
                args.hiddenEncoder,
                args.negativeSamplingExt,
                mode=args.cpc_mode,
                rnnMode=args.rnnMode,
                dropout=args.dropout,
                nSpeakers=nSpeakers,
                speakerEmbedding=args.speakerEmbedding,
                sizeInputSeq=sizeInputSeq)
    elif args.pathPhone is not None:
        if not args.CTC:
            cpcCriterion = cr.PhoneCriterion(dimFeatures,
                                             nPhones,
                                             args.onEncoder,
                                             nLayers=args.nLevelsPhone)
        else:
            cpcCriterion = cr.CTCPhoneCriterion(dimFeatures, nPhones,
                                                args.onEncoder)
    else:
        cpcCriterion = cr.SpeakerCriterion(dimFeatures, nSpeakers)
    return cpcCriterion
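
# Usage sketch (hypothetical, not from the original repo): each attribute
# below mirrors a field that getCriterion reads on the unsupervised path;
# the values are only illustrative, and `cr` is the criterion module the
# snippet above already assumes.
import argparse

example_args = argparse.Namespace(
    supervised=False, cpc_mode=None, sizeWindow=20480,
    nPredicts=12, hiddenGar=256, hiddenEncoder=256,
    negativeSamplingExt=128, rnnMode='transformer', dropout=False,
    speakerEmbedding=0, onEncoder=False, pathPhone=None, CTC=False)
# downsampling=160 assumes the usual 160x encoder downsampling at 16 kHz;
# nSpeakers=251 is just an example value (e.g. LibriSpeech clean-100).
criterion = getCriterion(example_args, downsampling=160,
                         nSpeakers=251, nPhones=None)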
Example #2
def getCriterion(args, downsampling, nSpeakers, nPhones):
    dimFeatures = args.hiddenGar if not args.onEncoder else args.hiddenEncoder
    if not args.supervised:
        if args.cpc_mode == 'none':
            cpcCriterion = cr.NoneCriterion()
        else:
            sizeInputSeq = (args.sizeWindow // downsampling)
            if args.CPCCTC:
                cpcCriterion = sa.CPCUnsupersivedCriterion(
                    args.nPredicts,
                    args.CPCCTCNumMatched,
                    args.hiddenGar,
                    args.hiddenEncoder,
                    args.negativeSamplingExt,
                    allowed_skips_beg=args.CPCCTCSkipBeg,
                    allowed_skips_end=args.CPCCTCSkipEnd,
                    predict_self_loop=args.CPCCTCSelfLoop,
                    learn_blank=args.CPCCTCLearnBlank,
                    normalize_enc=args.CPCCTCNormalizeEncs,
                    normalize_preds=args.CPCCTCNormalizePreds,
                    masq_rules=args.CPCCTCMasq,
                    loss_temp=args.CPCCTCLossTemp,
                    no_negs_in_match_window=args.CPCCTCNoNegsMatchWin,
                    limit_negs_in_batch=args.limitNegsInBatch,
                    mode=args.cpc_mode,
                    rnnMode=args.rnnMode,
                    dropout=args.dropout,
                    nSpeakers=nSpeakers,
                    speakerEmbedding=args.speakerEmbedding,
                    sizeInputSeq=sizeInputSeq)
            else:
                cpcCriterion = cr.CPCUnsupersivedCriterion(
                    args.nPredicts,
                    args.hiddenGar,
                    args.hiddenEncoder,
                    args.negativeSamplingExt,
                    mode=args.cpc_mode,
                    rnnMode=args.rnnMode,
                    dropout=args.dropout,
                    nSpeakers=nSpeakers,
                    speakerEmbedding=args.speakerEmbedding,
                    sizeInputSeq=sizeInputSeq)
    elif args.pathPhone is not None:
        if not args.CTC:
            cpcCriterion = cr.PhoneCriterion(dimFeatures,
                                             nPhones, args.onEncoder,
                                             nLayers=args.nLevelsPhone)
        else:
            cpcCriterion = cr.CTCPhoneCriterion(dimFeatures,
                                                nPhones, args.onEncoder)
    else:
        cpcCriterion = cr.SpeakerCriterion(dimFeatures, nSpeakers)
    return cpcCriterion
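
# Sketch of the extra switches read by the CPCCTC branch above; the names
# mirror the attributes used in the code, the values are illustrative only.
# Reuses example_args from the sketch under Example #1 and assumes `sa`
# provides the aligned CPCCTC criterion, as in the snippet.
example_args.__dict__.update(
    CPCCTC=True, CPCCTCNumMatched=12, CPCCTCSkipBeg=0, CPCCTCSkipEnd=0,
    CPCCTCSelfLoop=False, CPCCTCLearnBlank=False, CPCCTCNormalizeEncs=False,
    CPCCTCNormalizePreds=False, CPCCTCMasq=None, CPCCTCLossTemp=1.0,
    CPCCTCNoNegsMatchWin=False, limitNegsInBatch=0)
criterion = getCriterion(example_args, downsampling=160,
                         nSpeakers=251, nPhones=None)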
Example #3
            def constructPhoneCriterionAndOptimizer():
                if not args.CTCphones:
                    # Phone separability with aligned phones
                    phone_criterion = cr.PhoneCriterion(
                        dim_features, nPhonesInData, args.phone_get_encoded,
                        nLayers=args.linsep_net_layers)
                else:
                    # Phone separability with CTC loss
                    phone_criterion = cr.CTCPhoneCriterion(
                        dim_features, nPhonesInData, args.phone_get_encoded,
                        nLayers=args.linsep_net_layers)
                phone_criterion.cuda()
                phone_criterion = torch.nn.DataParallel(phone_criterion, device_ids=range(args.nGPU))

                # Optimizer
                phone_g_params = list(phone_criterion.parameters())

                phone_optimizer = torch.optim.Adam(
                    phone_g_params, lr=args.linsep_lr,
                    betas=(args.linsep_beta1, args.linsep_beta2),
                    eps=args.linsep_epsilon)
                
                return phone_criterion, phone_optimizer
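
            # Hypothetical follow-up in the same scope: the helper takes no
            # arguments and returns the phone probe with its optimizer.
            phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()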
Example #4
def main(argv):

    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if args.model == "cpc":
        def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1, get_encoded=False, keep_hidden=True):
            """
            Load CPC Feature Maker from CPC checkpoint file.
            """
            # Set LSTM level
            if gru_level is not None and gru_level > 0:
                updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
            else:
                updateConfig = None

            # Load CPC model
            model, nHiddenGar, nHiddenEncoder = fl.loadModel(pathCheckpoint, updateConfig=updateConfig)
            
            # Keep hidden units at LSTM layers on sequential batches
            model.gAR.keepHidden = keep_hidden

            return model, nHiddenGar, nHiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            model, hidden_gar, hidden_encoder = loadCPCFeatureMaker(args.load, gru_level=args.gru_level)
        else:
            model, hidden_gar, hidden_encoder = fl.loadModel(args.load,
                                                     loadStateDict=not args.no_pretraining)

        dim_features = hidden_encoder if args.get_encoded else hidden_gar
    else:
        sys.path.append(os.path.abspath(args.path_fairseq))
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """
            Load lstm_lm model from checkpoint.
            """
            # Set up the args Namespace
            model_args = argparse.Namespace(
                task="language_modeling",
                output_dictionary_size=-1,
                data=path_data,
                path=path_checkpoint
                )
            
            # Load model
            models, _model_args = checkpoint_utils.load_model_ensemble([model_args.path])
            model = models[0]
            return model

        model = loadCheckpoint(args.load[0], args.pathDB)
        dim_features = 768

    dim_inter = args.dim_inter
    # Now the criterion

    if args.mode == "phonemes_nullspace" or args.mode == "speakers_nullspace":
        speakers_factorized = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        speakers_factorized.load_state_dict(torch.load(args.path_speakers_factorized)["cpcCriterion"])
        for param in speakers_factorized.parameters():
            param.requires_grad = False

        def my_nullspace(At, rcond=None):
            # SVD-based nullspace: the rows of V^T whose singular values fall
            # below the tolerance span the nullspace of At.
            ut, st, vht = torch.svd(At, some=False, compute_uv=True)
            vht = vht.T
            Mt, Nt = ut.shape[0], vht.shape[1]
            if rcond is None:
                rcond = torch.finfo(st.dtype).eps * max(Mt, Nt)
            tolt = torch.max(st) * rcond
            numt = torch.sum(st > tolt, dtype=int)
            nullspace = vht[numt:, :].T.cpu().conj()
            return nullspace
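
        # Optional sanity check (illustrative, not in the original script):
        # the columns returned by my_nullspace should be annihilated by the
        # input matrix, i.e. At @ N ~= 0 up to float tolerance.
        #   At_test = torch.randn(8, 32)
        #   N_test = my_nullspace(At_test)   # shape (32, 32 - rank(At_test))
        #   assert torch.allclose(At_test @ N_test,
        #                         torch.zeros_like(At_test @ N_test), atol=1e-4)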

        dim_features = dim_features - dim_inter
        nullspace = my_nullspace(speakers_factorized.linearSpeakerClassifier[0].weight)
        model = CPCModelNullspace(model, nullspace)

    phone_labels = None
    if args.pathPhone is not None:

        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        label_key = 'phone'

        if not args.CTC:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features,
                                          n_phones, args.get_encoded)
        else:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features,
                                             n_phones, args.get_encoded)
    else:
        label_key = 'speaker'
        print(f"Running speaker separability")
        if args.mode == "speakers_factorized":
            criterion = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        else:
            criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers),
                              nProcessLoader=args.n_process_loader,
                              MAX_SIZE_LOADED=args.max_size_loaded)
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers), nProcessLoader=args.n_process_loader)

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)

    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    if args.centerpushFile:
        clustersFileExt = args.centerpushFile.split('.')[-1]
        assert clustersFileExt in ('pt', 'npy', 'txt')
        if clustersFileExt == 'npy':
            centers = np.load(args.centerpushFile)
        elif clustersFileExt == 'txt':
            centers = np.genfromtxt(args.centerpushFile)
        elif clustersFileExt == 'pt':  # assuming it's a checkpoint
            centers = torch.load(args.centerpushFile, map_location=torch.device('cpu'))['state_dict']['Ck']
            centers = torch.reshape(centers, centers.shape[1:]).numpy()
        centers = torch.tensor(centers).cuda()
        centerpushSettings = (centers, args.centerpushDeg)
    else:
        centerpushSettings = None

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint, label_key=label_key, centerpushSettings=centerpushSettings)
Example #5
def main(argv):

    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    model, hidden_gar, hidden_encoder = fl.loadModel(
        args.load, loadStateDict=not args.no_pretraining)
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    dim_features = hidden_encoder if args.get_encoded else hidden_gar

    # Now the criterion
    phone_labels = None
    if args.pathPhone is not None:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        if not args.CTC:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones,
                                          args.get_encoded)
        else:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones,
                                             args.get_encoded)
    else:
        print(f"Running speaker separability")
        criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers))
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers))

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size,
                                          "uniform",
                                          True,
                                          numWorkers=0)

    val_loader = db_val.getDataLoader(batch_size,
                                      'sequential',
                                      False,
                                      numWorkers=0)

    # Optimizer
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params,
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint)
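
# Hypothetical entry point (not shown in the original listing), assuming the
# module is executed directly with the flags consumed by parse_args above.
if __name__ == "__main__":
    import sys
    main(sys.argv[1:])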