def main(argv):
    """Parse ``argv``, build model, criterion, loaders and optimizer, then call ``run``.

    In order, the function:
      * lists the sequences and speakers found under ``args.pathDB``;
      * loads either a CPC model (``args.model == "cpc"``) or, for any other
        value, a fairseq checkpoint;
      * for the ``phonemes_nullspace`` / ``speakers_nullspace`` modes, wraps the
        model in ``CPCModelNullspace`` using the null space of a frozen,
        previously trained ``SpeakerDoubleCriterion``;
      * builds a phone criterion when ``args.pathPhone`` is given, otherwise a
        speaker criterion;
      * builds the train/validation loaders and an Adam optimizer;
      * dumps the parsed arguments to ``<pathCheckpoint>/checkpoint_args.json``;
      * optionally loads cluster centers for the "centerpush" setting.

    Args:
        argv: command-line arguments forwarded to ``parse_args``.
    """

    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if args.model == "cpc":
        def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1, get_encoded=False, keep_hidden=True):
            """
            Load a CPC model from a CPC checkpoint file.

            Despite its name, this helper returns the raw
            ``(model, nHiddenGar, nHiddenEncoder)`` tuple produced by
            ``fl.loadModel``; it does not wrap the model in a feature module.
            ``get_encoded`` is accepted but not used here.
            """
            # Only override the number of recurrent levels when a positive
            # level was requested.
            if gru_level is not None and gru_level > 0:
                updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
            else:
                updateConfig = None

            # Load CPC model
            model, nHiddenGar, nHiddenEncoder = fl.loadModel(pathCheckpoint, updateConfig=updateConfig)

            # Keep hidden units at LSTM layers on sequential batches
            model.gAR.keepHidden = keep_hidden

            return model, nHiddenGar, nHiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            model, hidden_gar, hidden_encoder = loadCPCFeatureMaker(args.load, gru_level=args.gru_level)
        else:
            model, hidden_gar, hidden_encoder = fl.loadModel(
                args.load, loadStateDict=not args.no_pretraining)

        dim_features = hidden_encoder if args.get_encoded else hidden_gar
    else:
        # fairseq is imported lazily from a user-supplied location.
        sys.path.append(os.path.abspath(args.path_fairseq))
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """
            Load lstm_lm model from checkpoint.

            Returns the first model of the ensemble loaded by fairseq.
            """
            # Set up the args Namespace (only ``path`` is read below).
            model_args = argparse.Namespace(
                task="language_modeling",
                output_dictionary_size=-1,
                data=path_data,
                path=path_checkpoint
                )

            # Load model
            models, _model_args = checkpoint_utils.load_model_ensemble([model_args.path])
            model = models[0]
            return model

        model = loadCheckpoint(args.load[0], args.pathDB)
        # NOTE(review): hard-coded feature size; presumably the hidden size of
        # the fairseq model being loaded - confirm against the checkpoint.
        dim_features = 768

    dim_inter = args.dim_inter
    # Now the criterion

    if args.mode == "phonemes_nullspace" or args.mode == "speakers_nullspace":
        # Load the previously trained factorized speaker classifier and freeze it.
        speakers_factorized = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        speakers_factorized.load_state_dict(torch.load(args.path_speakers_factorized)["cpcCriterion"])
        for param in speakers_factorized.parameters():
            param.requires_grad = False

        def my_nullspace(At, rcond=None):
            """Return a basis of the null space of ``At`` computed through an SVD.

            Args:
                At: 2-D tensor whose null space is wanted.
                rcond: relative tolerance under which a singular value counts
                    as zero. Defaults to machine epsilon of the singular values'
                    dtype times ``max(M, N)``.

            Returns:
                CPU tensor built from the right singular vectors whose singular
                values are not above the tolerance.
            """
            ut, st, vht = torch.Tensor.svd(At, some=False, compute_uv=True)
            # torch.svd returns V (not V^H); transpose so that rows index
            # singular vectors.
            vht = vht.T
            Mt, Nt = ut.shape[0], vht.shape[1]
            if rcond is None:
                rcondt = torch.finfo(st.dtype).eps * max(Mt, Nt)
            else:
                # Previously an explicit rcond left ``rcondt`` unbound.
                rcondt = rcond
            tolt = torch.max(st) * rcondt
            # Numerical rank: number of singular values above the tolerance.
            numt = torch.sum(st > tolt, dtype=int)
            nullspace = vht[numt:, :].T.cpu().conj()
            return nullspace

        # Features are reduced by the size of the intermediate layer.
        dim_features = dim_features - dim_inter
        nullspace = my_nullspace(speakers_factorized.linearSpeakerClassifier[0].weight)
        model = CPCModelNullspace(model, nullspace)

    phone_labels = None
    if args.pathPhone is not None:

        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        label_key = 'phone'

        if not args.CTC:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features,
                                          n_phones, args.get_encoded)
        else:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features,
                                             n_phones, args.get_encoded)
    else:
        label_key = 'speaker'
        print(f"Running speaker separability")
        if args.mode == "speakers_factorized":
            criterion = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        else:
            criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        # Debug mode: keep only a small prefix of each split.
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers),
                              nProcessLoader=args.n_process_loader,
                              MAX_SIZE_LOADED=args.max_size_loaded)
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers),
                            nProcessLoader=args.n_process_loader)

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)

    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer: the criterion is always trained; the feature model only
    # when --unfrozen is set.
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory; afterwards args.pathCheckpoint is the string
    # prefix "<dir>/checkpoint".
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    if args.centerpushFile:
        # Cluster centers may come from a .npy array, a text file or a
        # torch checkpoint.
        clustersFileExt = args.centerpushFile.split('.')[-1]
        assert clustersFileExt in ('pt', 'npy', 'txt')
        if clustersFileExt == 'npy':
            centers = np.load(args.centerpushFile)
        elif clustersFileExt == 'txt':
            centers = np.genfromtxt(args.centerpushFile)
        elif clustersFileExt == 'pt':  # assuming it's a checkpoint
            centers = torch.load(args.centerpushFile, map_location=torch.device('cpu'))['state_dict']['Ck']
            # Drop the leading dimension of the stored tensor.
            centers = torch.reshape(centers, centers.shape[1:]).numpy()
        centers = torch.tensor(centers).cuda()
        centerpushSettings = (centers, args.centerpushDeg)
    else:
        centerpushSettings = None

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint, label_key=label_key, centerpushSettings=centerpushSettings)
Beispiel #2
0
def main(argv):
    """Set up and launch a phone- or speaker-classification run on a loaded model.

    The parsed arguments decide which criterion is built: a phone criterion
    (aligned or CTC) when ``args.pathPhone`` is given, a speaker criterion
    otherwise. The feature model is trained only when ``args.unfrozen`` is
    set. The parsed arguments are written to
    ``<pathCheckpoint>/checkpoint_args.json`` before ``run`` is called.

    Args:
        argv: command-line arguments forwarded to ``parse_args``.
    """
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}

    seq_names, speakers = findAllSeqs(
        args.pathDB,
        extension=args.file_extension,
        loadCache=not args.ignore_cache,
    )

    # Feature model
    loaded = fl.loadModel(args.load, loadStateDict=not args.no_pretraining)
    model, hidden_gar, hidden_encoder = loaded
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    if args.get_encoded:
        dim_features = hidden_encoder
    else:
        dim_features = hidden_gar

    # Criterion: speakers unless phone labels were supplied
    if args.pathPhone is None:
        phone_labels = None
        print(f"Running speaker separability")
        criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    else:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        if args.CTC:
            print(f"Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones,
                                             args.get_encoded)
        else:
            print(f"Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones,
                                          args.get_encoded)
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    # Train / validation sequences (shortened in debug mode)
    train_seqs = filterSeqs(args.pathTrain, seq_names)
    val_seqs = filterSeqs(args.pathVal, seq_names)
    if args.debug:
        train_seqs, val_seqs = train_seqs[:1000], val_seqs[:100]

    train_db = AudioBatchData(args.pathDB, args.size_window, train_seqs,
                              phone_labels, len(speakers))
    val_db = AudioBatchData(args.pathDB, args.size_window, val_seqs,
                            phone_labels, len(speakers))

    batch_size = args.batchSizeGPU * args.nGPU
    train_loader = train_db.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)
    val_loader = val_db.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer: criterion parameters always, model parameters on demand
    trainable = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if not args.unfrozen:
        print("Working with frozen features")
        for weight in model.parameters():
            weight.requires_grad = False
    else:
        print("Working in full fine-tune mode")
        trainable += list(model.parameters())
        model.optimize = True

    adam_betas = (args.beta1, args.beta2)
    optimizer = torch.optim.Adam(trainable, lr=args.lr, betas=adam_betas,
                                 eps=args.epsilon)

    # Checkpoint directory, then the "<dir>/checkpoint" prefix used from here on
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    args_dump_path = f"{args.pathCheckpoint}_args.json"
    with open(args_dump_path, 'w') as file:
        json.dump(vars(args), file, indent=2)

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint)
Beispiel #3
0
def main(args):
    """Configure and launch CPC training, a capture-only run, or a classification-metric-only run.

    In order, the function:
      * parses the arguments and, unless ``args.restart`` is set, resumes from
        the checkpoint data found under ``args.pathCheckpoint``;
      * builds the train / validation / capture sequence lists and datasets;
      * loads a CPC model from ``args.load`` or builds a new one;
      * builds the CPC criterion, the optimizer and the optional schedulers;
      * when ``args.supervised_classif_metric`` is set, prepares the linear
        separability (phone / speaker) training routine;
      * finally dispatches to ``run``, ``onlyCapture`` and/or the linear
        separability routine depending on ``args.onlyCapture`` and
        ``args.only_classif_metric``.

    Args:
        args: command-line arguments forwarded to ``parseArgs``.
    """

    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    # The checkpoint path is optional everywhere below, so guard it here too
    # (os.makedirs(None) would raise).
    if args.pathCheckpoint is not None:
        os.makedirs(args.pathCheckpoint, exist_ok=True)
        if not args.onlyCapture and not args.only_classif_metric:
            # Context manager so the file handle is closed deterministically.
            with open(os.path.join(args.pathCheckpoint, 'checkpoint_args.json'), 'wt') as argsFile:
                json.dump(vars(args), argsFile)
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            # Restore the saved arguments, except the ones listed here which
            # keep their current command-line values.
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if not args.onlyCapture or args.only_classif_metric:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
        # Datasets
        if args.pathTrain is not None:
            seqTrain = filterSeqs(args.pathTrain, seqNames)
        else:
            seqTrain = seqNames

        if args.pathVal is None:
            # No validation list given: hold out 1% of the shuffled training
            # sequences.
            random.shuffle(seqTrain)
            sizeTrain = int(0.99 * len(seqTrain))
            seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
            print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
        else:
            seqVal = filterSeqs(args.pathVal, seqNames)

    if args.pathCaptureDS is not None:
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            # Pair each capture flag with the name of the item it enables.
            for argVal, name in zip([args.captureConvRepr,
                                     args.captureCtxRepr,
                                     args.captureSpeakerAlign,
                                     args.capturePhoneAlign,
                                     args.capturePred,
                                     args.captureCPCCTCalign,
                                     args.captureCPCCTClogScores],
                                    ['conv_repr', 'ctx_repr', 'speaker_align', 'phone_align',
                                     'pred', 'cpcctc_align', 'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        # NOTE: whatToSave may be empty here; no check is enforced.
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames,
                                percentage=args.captureDSfreq, totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None

    if not args.onlyCapture:
        if args.debug:
            # Debug mode: keep only the tail of each split.
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]

        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")

        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                      args.sizeWindow,
                                      seqTrain,
                                      phoneLabels,
                                      len(speakers),
                                      nProcessLoader=args.n_process_loader,
                                      MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")

        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None

    if seqCapture is not None:

        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            # Phone alignments cannot be captured without phone labels.
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None

        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                        args.sizeWindow,
                                        seqCapture,
                                        phoneLabelsForCapture,
                                        len(speakers),
                                        nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")

        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None

    if args.load is not None:
        if args.gru_level is not None and args.gru_level > 0:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None

        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # could use this option for loading best state when not running actual training
        # but relying on CPC internal acc isn't very reliable
        # [!] caution - because of how they capture checkpoints,
        #     they capture "best in this part of training" as "best" (apart from capturing current state)
        #     so if best is in epoch 100 and training is paused and resumed from checkpoint
        #     in epoch 150, checkpoint from epoch 200 has "best from epoch 150" saved as globally best
        #     (but this is internal-CPC-score best anyway, which is quite vague)
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load, load_nullspace=args.nullspace, updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            # Keep hidden units at LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

        CPChiddenGar, CPChiddenEncoder = cpcModel.gAR.getDimOutput(), cpcModel.gEncoder.getDimOutput()

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # The nullspace wrapper keeps the underlying CPC model in its ``cpc`` attribute.
    downsampling = cpcModel.cpc.gEncoder.DOWNSAMPLING if isinstance(cpcModel, model.CPCModelNullspace) else cpcModel.gEncoder.DOWNSAMPLING
    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], downsampling,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling,
                                    len(speakers), nPhones)

    # When resuming, read the checkpoint file once and reuse it for both the
    # criterion and (below) the optimizer state.
    checkpointState = None
    if loadOptimizer:
        checkpointState = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(checkpointState["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        if "optimizer" in checkpointState:
            optimizer.load_state_dict(checkpointState["optimizer"])

    # Checkpoint: from here on args.pathCheckpoint is the "<dir>/checkpoint" prefix.
    if args.pathCheckpoint is not None and not args.onlyCapture and not args.only_classif_metric:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                           lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                                                               n_epoch, epoch),
                                                           last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Replay one scheduler step per epoch already recorded in the logs so
        # a resumed run continues from the matching learning rate.
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for _ in range(len(logs["epoch"])):
            scheduler.step()

    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    if args.supervised_classif_metric:

        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU

        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        dim_ctx_features = CPChiddenGar  # for speakers using CNN encodings is not supported; could add but not very useful perhaps

        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)

            if not args.CTCphones:
                print(f"Running phone separability with aligned phones")
            else:
                print(f"Running phone separability with CTC loss")

            def constructPhoneCriterionAndOptimizer():
                """Build a fresh phone criterion (aligned or CTC) and its Adam optimizer."""
                if not args.CTCphones:
                    phone_criterion = cr.PhoneCriterion(dim_features,
                                                        nPhonesInData, args.phone_get_encoded,
                                                        nLayers=args.linsep_net_layers)
                else:
                    phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                           nPhonesInData, args.phone_get_encoded,
                                                           nLayers=args.linsep_net_layers)
                phone_criterion.cuda()
                phone_criterion = torch.nn.DataParallel(phone_criterion, device_ids=range(args.nGPU))

                # Optimizer
                phone_g_params = list(phone_criterion.parameters())

                phone_optimizer = torch.optim.Adam(phone_g_params, lr=args.linsep_lr,
                                                   betas=(args.linsep_beta1, args.linsep_beta2),
                                                   eps=args.linsep_epsilon)

                return phone_criterion, phone_optimizer

        if args.speaker_sep:
            print(f"Running speaker separability")

            def constructSpeakerCriterionAndOptimizer():
                """Build a fresh speaker criterion and its Adam optimizer."""
                speaker_criterion = cr.SpeakerCriterion(dim_ctx_features, len(speakers),
                                                        nLayers=args.linsep_net_layers)
                speaker_criterion.cuda()
                speaker_criterion = torch.nn.DataParallel(speaker_criterion, device_ids=range(args.nGPU))

                speaker_g_params = list(speaker_criterion.parameters())

                speaker_optimizer = torch.optim.Adam(speaker_g_params, lr=args.linsep_lr,
                                                     betas=(args.linsep_beta1, args.linsep_beta2),
                                                     eps=args.linsep_epsilon)

                return speaker_criterion, speaker_optimizer

        # NOTE(review): seqTrain / seqVal are only defined when
        # `not args.onlyCapture or args.only_classif_metric` held above.
        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
                                         phoneLabelsData, len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
                                       phoneLabelsData, len(speakers))

        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size, "uniform", True,
                                                            numWorkers=0)

        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size, 'sequential', False,
                                                        numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            """Train the phone and/or speaker classifiers on top of ``cpcMdl``.

            Logs go under ``<linsep_logs_dir>/<numOfEpoch>/{phoneme,speaker}/``;
            checkpoints go under the matching sub-directories of
            ``args.linsep_checkpoint_dir`` when that option is set.

            Returns:
                Dict merging the phone logs (keys prefixed with ``phone_``) and
                the speaker logs (keys prefixed with ``speaker_``).
            """
            log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            for log_dir in (log_path_for_epoch, log_path_phoneme, log_path_speaker):
                os.makedirs(log_dir, exist_ok=True)

            # Default to None so the calls below do not hit an unbound local
            # when no checkpoint directory was requested.
            # NOTE(review): assumes linsep.trainLinsepClassification accepts
            # None as "no checkpoint path" - confirm.
            checkpoint_path_phoneme = None
            checkpoint_path_speaker = None
            if args.linsep_checkpoint_dir:
                checkpoint_path_for_epoch = os.path.join(args.linsep_checkpoint_dir, str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checkpoint_path_for_epoch, "phoneme/")
                checkpoint_path_speaker = os.path.join(checkpoint_path_for_epoch, "speaker/")
                os.makedirs(checkpoint_path_phoneme, exist_ok=True)
                os.makedirs(checkpoint_path_speaker, exist_ok=True)
            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer

            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                          runLinsepClassificationTraining)

    else:
        linsepClassificationTaskConfig = (None, None)

    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # caution [!] - will capture for last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # caution [!] - will use last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        trainedEpoch = len(logs["epoch"]) - 1
        # runLinsepClassificationTraining is only created above when
        # args.supervised_classif_metric is set.
        # NOTE(review): without that flag this call raises NameError.
        runLinsepClassificationTraining(trainedEpoch, cpcModel, trainedEpoch)
Beispiel #4
0
    args = parser.parse_args()
    print(sys.argv)

    if args.command == 'per':
        args = get_PER_args(args)

    # Output Directory
    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    name = f"_{args.name}" if args.command == "per" else ""
    pathLogs = os.path.join(args.output, f'logs_{args.command}{name}.txt')
    tee = subprocess.Popen(["tee", pathLogs], stdin=subprocess.PIPE)
    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())

    phoneLabels, nPhones = parseSeqLabels(args.pathPhone)

    inSeqs, _ = findAllSeqs(args.pathDB, extension=args.file_extension)
    # Datasets
    if args.command == 'train' and args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, inSeqs)
    else:
        seqTrain = inSeqs

    if args.pathVal is None and args.command == 'train':
        random.shuffle(seqTrain)
        sizeTrain = int(0.9 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
    elif args.pathVal is not None:
        seqVal = filterSeqs(args.pathVal, inSeqs)
    else:
Beispiel #5
0
def main(args):
    """Set up and launch a CPC training run.

    Parses the arguments, optionally resumes from the checkpoint found under
    ``args.pathCheckpoint``, builds the training and validation datasets, the
    CPC model, its criterion, the Adam optimizer and the learning-rate
    scheduler, and finally hands everything over to ``run``.

    Args:
        args: raw arguments forwarded to ``parseArgs`` (presumably the
            command-line argument list -- confirm against ``parseArgs``).

    Returns:
        None. The value returned by ``run`` is not propagated.
    """
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            # Restore the saved arguments, except the ones listed in
            # forbiddenAttr, which keep the values of the current invocation.
            fl.loadArgs(args,
                        locArgs,
                        forbiddenAttr={
                            "nGPU", "pathCheckpoint", "debug", "restart",
                            "world_size", "n_nodes", "node_id",
                            "n_gpu_per_node", "max_size_loaded"
                        })
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        # No explicit validation list: hold out the last 1% of the shuffled
        # training sequences for validation.
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        # Debug mode works on a small tail of each split.
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]

    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")

    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")

    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")

    if args.load is not None:
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0],
                                     cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)

    # When resuming, deserialize the checkpoint a single time and reuse it
    # for both the criterion (here) and the optimizer (below).
    checkpointState = None
    if loadOptimizer:
        checkpointState = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(checkpointState["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params,
                                 lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        if "optimizer" in checkpointState:
            optimizer.load_state_dict(checkpointState["optimizer"])
    # Drop the reference so the checkpoint content is not kept alive for the
    # whole duration of the training run.
    checkpointState = None

    # Checkpoint
    if args.pathCheckpoint is not None:
        # makedirs also creates missing parent directories and tolerates an
        # already existing directory.
        os.makedirs(args.pathCheckpoint, exist_ok=True)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Fast-forward the scheduler by one step per epoch already logged
        # (non-empty only when resuming from a checkpoint).
        for _ in range(len(logs["epoch"])):
            scheduler.step()

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    run(trainDataset, valDataset, batchSize, args.samplingType, cpcModel,
        cpcCriterion, args.nEpoch, args.pathCheckpoint, optimizer, scheduler,
        logs)