Beispiel #1
0
def CPC_audio(pretrained=False, **kwargs):
    """
    Contrast predictive learning model for audio data.

    pretrained: if True, load a model trained on libri-light 60k
    (https://arxiv.org/abs/1912.07875). The architecture is then taken from
    the configuration stored in the checkpoint, so any **kwargs are ignored;
    a warning is emitted if some were given.
    **kwargs : see cpc/cpc_default_config to get the list of possible arguments
    (only used when pretrained is False).

    Returns the constructed CPC model (with pretrained weights if requested).
    """
    locArgs = get_default_cpc_config()
    if pretrained:
        if kwargs:
            # Previously these were dropped silently; tell the caller.
            import warnings
            warnings.warn(
                "CPC_audio: keyword arguments are ignored when pretrained=True "
                f"(got {sorted(kwargs)})")
        checkpoint_url = 'https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints/60k_epoch4-d0f474de.pt'
        checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url,
                                                        progress=False)
        # The checkpoint's own config defines the architecture to rebuild.
        loadArgs(locArgs, argparse.Namespace(**checkpoint["config"]))
    else:
        args = argparse.Namespace(**kwargs)
        loadArgs(locArgs, args)
    encoderNet = getEncoder(locArgs)
    arNet = getAR(locArgs)
    model = cpcmodel(encoderNet, arNet)
    if pretrained:
        # strict=False: keys that do not match the model are skipped
        # rather than raising.
        model.load_state_dict(checkpoint["weights"], strict=False)
    return model
Beispiel #2
0
    downsampling_factor = 160
    if args.pathCheckpoint == 'ID':
        downsampling_factor = 1
        feature_maker = IDModule()
        hiddenGar = args.in_dim
    elif args.pathCheckpoint == "hub":
        print("Load model from pytorch hub")
        locArgs = get_default_cpc_config()
        checkpoint_url = 'https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints/60k_epoch4-d0f474de.pt'
        checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url,
                                                        progress=False,
                                                        map_location="cuda:0")
        loadArgs(locArgs, argparse.Namespace(**checkpoint["config"]))
        encoderNet = getEncoder(locArgs)
        arNet = getAR(locArgs)
        model = CPCModel(encoderNet, arNet)
        if not args.no_pretraining:
            model.load_state_dict(checkpoint["weights"], strict=False)
        feature_maker = model
        hiddenGar = locArgs.hiddenGar
        print(feature_maker, hiddenGar)
        print()
    else:
        feature_maker, hiddenGar, _ = loadModel(
            [args.pathCheckpoint], loadStateDict=not args.no_pretraining)
    feature_maker.cuda()
    feature_maker = torch.nn.DataParallel(feature_maker)

    phone_criterion = CTCphone_criterion(hiddenGar,
                                         nPhones,
Beispiel #3
0
 def testbuildNoAR(self):
     """Check that arMode='RNN' yields a CPCAR whose baseNet is a torch.nn.RNN.

     NOTE(review): despite its name, this test exercises the 'RNN' AR mode,
     not the 'no_ar' mode; consider renaming it (e.g. testbuildRNN).
     """
     from cpc.model import CPCAR
     args = self.default_args
     args.arMode = 'RNN'
     ar_net = fl.getAR(args)
     ok_(isinstance(ar_net, CPCAR))
     ok_(isinstance(ar_net.baseNet, torch.nn.RNN))
Beispiel #4
0
    def testbuildNoAR(self):
        """Check that arMode='no_ar' makes fl.getAR return a NoAr module."""
        from cpc.model import NoAr
        args = self.default_args
        args.arMode = 'no_ar'
        ok_(isinstance(fl.getAR(args), NoAr))
Beispiel #5
0
def main(args):
    """Entry point of the CPC training / capture / linear-separability script.

    Depending on the flags, this either trains the CPC model (``run``), only
    captures representations of a capture dataset (``args.onlyCapture``), or
    only runs the linear-separability classification metric on an existing
    model (``args.only_classif_metric``).

    Args:
        args: raw command-line argument list; parsed with ``parseArgs``.

    Raises:
        ValueError: if ``only_classif_metric`` is requested without
            ``supervised_classif_metric`` (the classification runner is only
            built in that case).
    """
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    if args.pathCheckpoint is not None:
        # Guarded: os.makedirs(None) would raise TypeError.
        os.makedirs(args.pathCheckpoint, exist_ok=True)
        if not args.onlyCapture and not args.only_classif_metric:
            # NOTE(review): this file is written before getCheckpointData()
            # below runs; if that helper reads checkpoint_args.json, a resumed
            # run picks up the current CLI arguments rather than the stored
            # ones - confirm this is intended.
            with open(os.path.join(args.pathCheckpoint, 'checkpoint_args.json'), 'wt') as argsFile:
                json.dump(vars(args), argsFile)
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if not args.onlyCapture or args.only_classif_metric:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
        # Datasets
        if args.pathTrain is not None:
            seqTrain = filterSeqs(args.pathTrain, seqNames)
        else:
            seqTrain = seqNames

        if args.pathVal is None:
            # Hold out 1% of the shuffled training sequences for validation.
            random.shuffle(seqTrain)
            sizeTrain = int(0.99 * len(seqTrain))
            seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
            print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
        else:
            seqVal = filterSeqs(args.pathVal, seqNames)

    if args.pathCaptureDS is not None:
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            for argVal, name in zip([args.captureConvRepr,
                                     args.captureCtxRepr,
                                     args.captureSpeakerAlign,
                                     args.capturePhoneAlign,
                                     args.capturePred,
                                     args.captureCPCCTCalign,
                                     args.captureCPCCTClogScores],
                                    ['conv_repr', 'ctx_repr', 'speaker_align', 'phone_align', 'pred', 'cpcctc_align', 'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        # NOTE: an empty whatToSave is accepted here (a non-empty assert was
        # previously disabled).
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames,
                                percentage=args.captureDSfreq, totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None

    if not args.onlyCapture:
        if args.debug:
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]

        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")

        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                      args.sizeWindow,
                                      seqTrain,
                                      phoneLabels,
                                      len(speakers),
                                      nProcessLoader=args.n_process_loader,
                                      MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")

        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None

    if seqCapture is not None:

        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None

        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                        args.sizeWindow,
                                        seqCapture,
                                        phoneLabelsForCapture,
                                        len(speakers),
                                        nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")

        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None

    if args.load is not None:
        useGruLevel = args.gru_level is not None and args.gru_level > 0
        if useGruLevel:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None

        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # could use this option for loading best state when not running actual training
        # but relying on CPC internal acc isn't very reliable
        # [!] caution - because of how they capture checkpoints,
        #     they capture "best in this part of training" as "best" (apart from capturing current state)
        #     so if best is in epoch 100 and training is paused and resumed from checkpoint
        #     in epoch 150, checkpoint from epoch 200 has "best from epoch 150" saved as globally best
        #     (but this is internal-CPC-score best anyway, which is quite vague)
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load, load_nullspace=args.nullspace, updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder

        if useGruLevel:
            # Keep hidden units at LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

        CPChiddenGar, CPChiddenEncoder = cpcModel.gAR.getDimOutput(), cpcModel.gEncoder.getDimOutput()

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # The nullspace model wraps the CPC model in `.cpc`.
    downsampling = cpcModel.cpc.gEncoder.DOWNSAMPLING if isinstance(cpcModel, model.CPCModelNullspace) else cpcModel.gEncoder.DOWNSAMPLING
    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], downsampling,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling,
                                    len(speakers), nPhones)

    # Read the checkpoint once; it is reused below for the optimizer state.
    checkpointState = None
    if loadOptimizer:
        checkpointState = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(checkpointState["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        if "optimizer" in checkpointState:
            optimizer.load_state_dict(checkpointState["optimizer"])
    # Drop the CPU copy of the checkpoint so it is not held during training.
    checkpointState = None

    # Checkpoint
    if args.pathCheckpoint is not None and not args.onlyCapture and not args.only_classif_metric:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                           lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                                                               n_epoch, epoch),
                                                           last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Fast-forward the schedule to the number of epochs already logged.
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for _ in logs["epoch"]:
            scheduler.step()

    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    if args.supervised_classif_metric:

        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU

        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        dim_ctx_features = CPChiddenGar  # for speakers using CNN encodings is not supported; could add but not very useful perhaps

        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)

            if not args.CTCphones:
                print(f"Running phone separability with aligned phones")
            else:
                print(f"Running phone separability with CTC loss")

            def constructPhoneCriterionAndOptimizer():
                """Build a fresh phone classifier (aligned or CTC) and its Adam optimizer."""
                if not args.CTCphones:
                    phone_criterion = cr.PhoneCriterion(dim_features,
                                                        nPhonesInData, args.phone_get_encoded,
                                                        nLayers=args.linsep_net_layers)
                else:
                    phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                           nPhonesInData, args.phone_get_encoded,
                                                           nLayers=args.linsep_net_layers)
                phone_criterion.cuda()
                phone_criterion = torch.nn.DataParallel(phone_criterion, device_ids=range(args.nGPU))

                # Optimizer
                phone_g_params = list(phone_criterion.parameters())

                phone_optimizer = torch.optim.Adam(phone_g_params, lr=args.linsep_lr,
                                                   betas=(args.linsep_beta1, args.linsep_beta2),
                                                   eps=args.linsep_epsilon)

                return phone_criterion, phone_optimizer

        if args.speaker_sep:
            print(f"Running speaker separability")

            def constructSpeakerCriterionAndOptimizer():
                """Build a fresh speaker classifier and its Adam optimizer."""
                speaker_criterion = cr.SpeakerCriterion(dim_ctx_features, len(speakers),
                                                        nLayers=args.linsep_net_layers)
                speaker_criterion.cuda()
                speaker_criterion = torch.nn.DataParallel(speaker_criterion, device_ids=range(args.nGPU))

                speaker_g_params = list(speaker_criterion.parameters())

                speaker_optimizer = torch.optim.Adam(speaker_g_params, lr=args.linsep_lr,
                                                     betas=(args.linsep_beta1, args.linsep_beta2),
                                                     eps=args.linsep_epsilon)

                return speaker_criterion, speaker_optimizer

        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
                                         phoneLabelsData, len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
                                       phoneLabelsData, len(speakers))

        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size, "uniform", True,
                                                            numWorkers=0)

        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size, 'sequential', False,
                                                        numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            """Train fresh phone/speaker classifiers on top of ``cpcMdl``.

            Returns the phone and speaker logs merged into one dict, with keys
            prefixed by ``phone_`` / ``speaker_``.
            """
            log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            # makedirs also creates log_path_for_epoch as the parent.
            os.makedirs(log_path_phoneme, exist_ok=True)
            os.makedirs(log_path_speaker, exist_ok=True)

            # Default to None so the calls below no longer hit an unbound local
            # (UnboundLocalError) when no checkpoint directory is configured.
            # NOTE(review): assumes linsep.trainLinsepClassification treats None
            # as "do not save checkpoints" - confirm.
            checkpoint_path_phoneme = None
            checkpoint_path_speaker = None
            if args.linsep_checkpoint_dir:
                checkpoint_path_for_epoch = os.path.join(args.linsep_checkpoint_dir, str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checkpoint_path_for_epoch, "phoneme/")
                checkpoint_path_speaker = os.path.join(checkpoint_path_for_epoch, "speaker/")
                os.makedirs(checkpoint_path_phoneme, exist_ok=True)
                os.makedirs(checkpoint_path_speaker, exist_ok=True)

            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer

            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                          runLinsepClassificationTraining)

    else:
        linsepClassificationTaskConfig = (None, None)

    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # caution [!] - will capture for last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # caution [!] - will use last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        _, linsepRunner = linsepClassificationTaskConfig
        if linsepRunner is None:
            # The runner is only built when supervised_classif_metric is set;
            # fail with a clear message instead of a NameError.
            raise ValueError("only_classif_metric requires supervised_classif_metric "
                             "to be enabled")
        trainedEpoch = len(logs["epoch"]) - 1
        linsepRunner(trainedEpoch, cpcModel, trainedEpoch)
Beispiel #6
0
def main(args):
    """Parse arguments, build or restore a CPC model and train it.

    If ``args.pathCheckpoint`` already holds a checkpoint (and ``restart`` is
    not set), the stored arguments, model, criterion and optimizer state are
    restored and training resumes from it.

    Args:
        args: raw command-line argument list; parsed with ``parseArgs``.
    """
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args,
                        locArgs,
                        forbiddenAttr={
                            "nGPU", "pathCheckpoint", "debug", "restart",
                            "world_size", "n_nodes", "node_id",
                            "n_gpu_per_node", "max_size_loaded"
                        })
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        # Hold out 1% of the shuffled training sequences for validation.
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]

    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")

    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")

    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")

    if args.load is not None:
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0],
                                     cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)

    # Read the checkpoint once; it is reused below for the optimizer state.
    checkpointState = None
    if loadOptimizer:
        checkpointState = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(checkpointState["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params,
                                 lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        if "optimizer" in checkpointState:
            optimizer.load_state_dict(checkpointState["optimizer"])
    # Drop the CPU copy of the checkpoint so it is not held during training.
    checkpointState = None

    # Checkpoint
    if args.pathCheckpoint is not None:
        if not os.path.isdir(args.pathCheckpoint):
            os.mkdir(args.pathCheckpoint)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Fast-forward the schedule to the number of epochs already logged.
        for _ in logs["epoch"]:
            scheduler.step()

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    run(trainDataset, valDataset, batchSize, args.samplingType, cpcModel,
        cpcCriterion, args.nEpoch, args.pathCheckpoint, optimizer, scheduler,
        logs)