예제 #1
0
def loadCPCFeatureMaker(pathCheckpoint,
                        gru_level=-1,
                        get_encoded=False,
                        keep_hidden=True,
                        load_nullspace=False):
    """
    Load a CPC model from a checkpoint and wrap it in a FeatureModule.

    Args:
        pathCheckpoint: checkpoint path; passed to ``loadModel`` as a
            one-element list.
        gru_level: when not None and strictly positive, forwarded to
            ``loadModel`` as ``updateConfig.nLevelsGRU``; otherwise no
            config override is applied.
        get_encoded: forwarded to ``FeatureModule``.
        keep_hidden: value assigned to the ``keepHidden`` attribute of the
            model's ``gAR`` sub-module.
        load_nullspace: forwarded to ``loadModel``; when truthy the ``gAR``
            sub-module is reached through ``model.cpc`` instead of ``model``.

    Returns:
        The ``FeatureModule`` built around the loaded model.
    """
    # Only override the number of GRU levels when a positive value is given
    override_levels = gru_level is not None and gru_level > 0
    updateConfig = (argparse.Namespace(nLevelsGRU=gru_level)
                    if override_levels else None)

    # loadModel returns a 3-tuple; only the model itself is needed here
    model, _, _ = loadModel([pathCheckpoint],
                            updateConfig=updateConfig,
                            load_nullspace=load_nullspace)

    # With a nullspace wrapper the CPC network lives under `.cpc`
    cpc_net = model.cpc if load_nullspace else model
    cpc_net.gAR.keepHidden = keep_hidden

    return FeatureModule(model, get_encoded=get_encoded)
예제 #2
0
def main(argv):
    """
    Compute ABX scores for a set of sequences and dump them as JSON.

    The features are either computed on the fly from a CPC checkpoint
    (``args.load == 'from_checkpoint'``) or loaded from pre-computed files
    with ``torch.load`` (``args.load == 'from_pre_computed'``).

    Two files are written to the output directory (``args.out`` when given,
    otherwise the directory containing ``args.path_checkpoint``):
    ``ABX_scores.json`` with the scores and ``ABX_args.json`` with the
    parsed arguments.

    Args:
        argv: command-line arguments handed to ``parse_args``.

    Raises:
        ValueError: if ``args.load`` is not one of the two supported modes.
    """
    args = parse_args(argv)

    if args.load == 'from_checkpoint':
        # Checkpoint
        model = loadModel([args.path_checkpoint])[0]
        model.gAR.keepHidden = True
        # Feature maker
        feature_maker = FeatureModule(model, args.get_encoded).cuda().eval()

        def feature_function(x):
            return buildFeature(feature_maker, x,
                                seqNorm=args.seq_norm,
                                strict=args.strict,
                                maxSizeSeq=args.max_size_seq)
    elif args.load == 'from_pre_computed':
        def feature_function(x):
            return torch.load(x, 'cpu')
    else:
        # Without this branch `feature_function` would stay unbound and the
        # failure would only surface much later as an UnboundLocalError.
        raise ValueError(
            f"Unsupported value for args.load: {args.load!r} (expected "
            "'from_checkpoint' or 'from_pre_computed')")

    # Modes
    if args.mode == 'all':
        modes = ["within", "across"]
    else:
        modes = [args.mode]

    distance_mode = 'cosine'

    step_feature = 1 / args.feature_size

    # Get the list of sequences as (stem, full path) pairs
    seq_list, _ = findAllSeqs(args.path_dataset, extension=args.file_extension)
    seq_list = [(str(Path(x).stem), str(Path(args.path_dataset) / x))
                for (_, x) in seq_list]

    if args.debug:
        seq_list = seq_list[:1000]

    scores = ABX(feature_function, args.path_item_file,
                 seq_list, distance_mode,
                 step_feature, modes,
                 cuda=args.cuda,
                 seq_norm=args.seq_norm,
                 max_x_across=args.max_x_across,
                 max_size_group=args.max_size_group)

    out_dir = Path(args.path_checkpoint).parent if args.out is None \
        else Path(args.out)
    # parents=True: a nested output path must not make the run fail after
    # the scores have already been computed.
    out_dir.mkdir(parents=True, exist_ok=True)

    path_score = out_dir / 'ABX_scores.json'
    with open(path_score, 'w') as file:
        json.dump(scores, file, indent=2)

    path_args = out_dir / 'ABX_args.json'
    with open(path_args, 'w') as file:
        json.dump(vars(args), file, indent=2)
예제 #3
0
    randomOffset=False,
    numWorkers=0)

# Test split: dataset built from TS_LIST_PATH, read through a sequential
# loader without random offsets.
db_test = LibriSelectionDataset(sizeWindow=20480,
                                db_wav_root=DB_WAV_ROOT,
                                fps_list=TS_LIST_PATH,
                                label_path=LABEL_PATH,
                                n_process_loader=8,
                                MAX_SIZE_LOADED=400000000)
test_loader = db_test.getDataLoader(
    batchSize=128,
    type='sequential',  #'sequential',
    randomOffset=False,
    numWorkers=0)
"""Load model: c, z, label = feat_gen(raw_audio, label)"""
# fl.loadModel returns the model plus two dimensions, used below as the
# sizes of the 'c' and 'z' features.
feat_gen, d_c, d_z = fl.loadModel([CPC_CHECKPOINT_PATH], loadStateDict=True)

# Move the feature generator to GPU and wrap it for multi-GPU use if any
# GPU is requested.
if nGPU > 0:
    feat_gen = feat_gen.cuda()
    feat_gen = torch.nn.DataParallel(feat_gen, device_ids=range(nGPU))
# The feature generator is used frozen: eval mode, no gradients.
feat_gen.optimize = False
feat_gen.eval()
for g in feat_gen.parameters():
    g.requires_grad = False
"""Create classifier: clf()"""
# Classifier input size: d_c and/or d_z are added depending on whether 'c'
# and/or 'z' appear in SEL_FEAT.
d_feat = int('c' in SEL_FEAT) * d_c + int('z' in SEL_FEAT) * d_z
# One output class per speaker of the training set (db_train is defined
# earlier in the script).
n_classes = len(db_train.speakers)
#clf = MLP(d_feat, 2048, n_classes)
clf = SpeakerClf(d_feat, n_classes)
if nGPU > 0:
    clf = clf.cuda()
예제 #4
0
def main(argv):
    """
    Train a linear-separability criterion (phones or speakers) on top of a
    feature model.

    Steps, in order:
      1. Parse the arguments and list the sequences of ``args.pathDB``.
      2. Load the feature model: a CPC model through ``fl.loadModel``
         (``args.model == "cpc"``), otherwise the first model of a fairseq
         checkpoint ensemble.
      3. In the ``phonemes_nullspace`` / ``speakers_nullspace`` modes, load a
         frozen ``SpeakerDoubleCriterion``, compute the null space of the
         weight of its first linear speaker-classifier layer and wrap the
         model in ``CPCModelNullspace``.
      4. Build the criterion (phone, CTC phone, speaker or factorized
         speaker), the train/validation loaders and the Adam optimizer.
      5. Dump the arguments next to the checkpoints and call ``run``.

    Args:
        argv: command-line arguments handed to ``parse_args``.
    """
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if args.model == "cpc":
        def loadCPCFeatureMaker(pathCheckpoint, gru_level=-1, get_encoded=False, keep_hidden=True):
            """
            Load a CPC model from a checkpoint, optionally overriding the
            number of GRU levels, and return ``(model, nHiddenGar,
            nHiddenEncoder)`` as given by ``fl.loadModel``.

            ``get_encoded`` is accepted for signature compatibility but is
            not used: the model is returned without a FeatureModule wrapper.
            """
            # Only override the GRU depth when a positive level is requested
            if gru_level is not None and gru_level > 0:
                updateConfig = argparse.Namespace(nLevelsGRU=gru_level)
            else:
                updateConfig = None

            # Load CPC model
            model, nHiddenGar, nHiddenEncoder = fl.loadModel(pathCheckpoint, updateConfig=updateConfig)

            # Keep hidden units of the AR network on sequential batches
            model.gAR.keepHidden = keep_hidden

            return model, nHiddenGar, nHiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            model, hidden_gar, hidden_encoder = loadCPCFeatureMaker(args.load, gru_level=args.gru_level)
        else:
            model, hidden_gar, hidden_encoder = fl.loadModel(args.load,
                                                             loadStateDict=not args.no_pretraining)

        dim_features = hidden_encoder if args.get_encoded else hidden_gar
    else:
        sys.path.append(os.path.abspath(args.path_fairseq))
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """
            Load the first model of the fairseq ensemble stored at
            ``path_checkpoint``.
            """
            # Set up the args Namespace
            model_args = argparse.Namespace(
                task="language_modeling",
                output_dictionary_size=-1,
                data=path_data,
                path=path_checkpoint
                )

            # Load model
            models, _model_args = checkpoint_utils.load_model_ensemble([model_args.path])
            return models[0]

        model = loadCheckpoint(args.load[0], args.pathDB)
        dim_features = 768

    dim_inter = args.dim_inter
    # Now the criterion

    if args.mode == "phonemes_nullspace" or args.mode == "speakers_nullspace":
        # Frozen, pre-trained factorized speaker classifier
        speakers_factorized = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        speakers_factorized.load_state_dict(torch.load(args.path_speakers_factorized)["cpcCriterion"])
        for param in speakers_factorized.parameters():
            param.requires_grad = False

        def my_nullspace(At, rcond=None):
            """
            Return a matrix whose columns are the right singular vectors of
            ``At`` associated with singular values at or below the tolerance
            ``max(singular values) * rcond``.

            When ``rcond`` is None it defaults to the machine epsilon of the
            singular values' dtype times the largest dimension involved.
            The result is moved to the CPU.
            """
            ut, st, vht = torch.Tensor.svd(At, some=False, compute_uv=True)
            # svd's third output is transposed so that its rows can be
            # sliced by rank below
            vht = vht.T
            Mt, Nt = ut.shape[0], vht.shape[1]
            # Bind the relative tolerance in both cases: previously an
            # explicit `rcond` left `rcondt` undefined.
            if rcond is None:
                rcondt = torch.finfo(st.dtype).eps * max(Mt, Nt)
            else:
                rcondt = rcond
            tolt = torch.max(st) * rcondt
            # Number of singular values above the tolerance
            numt = torch.sum(st > tolt, dtype=int)
            nullspace = vht[numt:, :].T.cpu().conj()
            return nullspace

        dim_features = dim_features - dim_inter
        nullspace = my_nullspace(speakers_factorized.linearSpeakerClassifier[0].weight)
        model = CPCModelNullspace(model, nullspace)

    phone_labels = None
    if args.pathPhone is not None:

        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        label_key = 'phone'

        if not args.CTC:
            print("Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features,
                                          n_phones, args.get_encoded)
        else:
            print("Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features,
                                             n_phones, args.get_encoded)
    else:
        label_key = 'speaker'
        print("Running speaker separability")
        if args.mode == "speakers_factorized":
            criterion = cr.SpeakerDoubleCriterion(dim_features, dim_inter, len(speakers))
        else:
            criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    # Dataset
    seq_train = filterSeqs(args.pathTrain, seqNames)
    seq_val = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seq_train = seq_train[:1000]
        seq_val = seq_val[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, seq_train,
                              phone_labels, len(speakers), nProcessLoader=args.n_process_loader,
                              MAX_SIZE_LOADED=args.max_size_loaded)
    db_val = AudioBatchData(args.pathDB, args.size_window, seq_val,
                            phone_labels, len(speakers), nProcessLoader=args.n_process_loader)

    batch_size = args.batchSizeGPU * args.nGPU

    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)

    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Optimizer: the criterion is always trained; the feature model only in
    # the unfrozen (fine-tuning) mode.
    g_params = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if args.unfrozen:
        print("Working in full fine-tune mode")
        g_params += list(model.parameters())
        model.optimize = True
    else:
        print("Working with frozen features")
        for g in model.parameters():
            g.requires_grad = False

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory; args.pathCheckpoint becomes the file prefix
    # "<dir>/checkpoint" (as a str so that vars(args) stays JSON-serializable)
    args.pathCheckpoint = Path(args.pathCheckpoint)
    args.pathCheckpoint.mkdir(exist_ok=True)
    args.pathCheckpoint = str(args.pathCheckpoint / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as file:
        json.dump(vars(args), file, indent=2)

    if args.centerpushFile:
        # Cluster centers can come from a .npy array, a text file or a
        # checkpoint holding them under state_dict['Ck'].
        clustersFileExt = args.centerpushFile.split('.')[-1]
        assert clustersFileExt in ('pt', 'npy', 'txt')
        if clustersFileExt == 'npy':
            centers = np.load(args.centerpushFile)
        elif clustersFileExt == 'txt':
            centers = np.genfromtxt(args.centerpushFile)
        elif clustersFileExt == 'pt':  # assuming it's a checkpoint
            centers = torch.load(args.centerpushFile, map_location=torch.device('cpu'))['state_dict']['Ck']
            # Drop the leading dimension of the stored tensor
            centers = torch.reshape(centers, centers.shape[1:]).numpy()
        centers = torch.tensor(centers).cuda()
        centerpushSettings = (centers, args.centerpushDeg)
    else:
        centerpushSettings = None

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint, label_key=label_key, centerpushSettings=centerpushSettings)
def main(argv):
    """
    Quantize the audio files of ``args.pathDB`` with a clustering module and
    write one line per file to a text file in ``args.pathOutput``.

    Each output line is ``<file name without extension>\\t<units>`` where the
    units are comma-separated and, within a unit, the per-group cluster
    indices are joined with ``-``.

    Steps, in order:
      1. Parse arguments, create the output directory if needed.
      2. List the sequences, optionally restricted to the names listed in
         ``args.pathSeq``, to one split (``args.split`` = "idx-num") and, in
         debug mode, to the first 20 files.
      3. Read the clustering arguments stored next to the clustering
         checkpoint, load the cluster module and the CPC feature maker.
      4. Quantize every file and save the result.

    Args:
        argv: command-line arguments handed to ``parseArgs``.

    Raises:
        NotImplementedError: if ``args.separate_speaker`` is set; this
            function only knows how to write a single combined output file.
        AssertionError: on a malformed split, an already existing output
            file, a checkpoint path not ending in ".pt" or a missing
            clustering args file.
    """
    # Args parser
    args = parseArgs(argv)

    # The single output file is the only destination this function writes
    # to. Previously the separate-speaker mode left `outputFile` unbound and
    # crashed only after every file had been quantized; fail before doing
    # any work instead.
    if args.separate_speaker:
        raise NotImplementedError(
            "The separate_speaker option is not supported: only a single "
            "combined output file can be written.")

    print("=============================================================")
    print(f"Quantizing data from {args.pathDB}")
    print("=============================================================")

    # Check if directory exists
    if not os.path.exists(args.pathOutput):
        print("")
        print(f"Creating the output directory at {args.pathOutput}")
        Path(args.pathOutput).mkdir(parents=True, exist_ok=True)

    # Get splits
    if args.split:
        split_parts = args.split.split("-")
        assert len(split_parts) == 2 and int(split_parts[1]) >= int(split_parts[0]) >= 1, \
            "SPLIT must be under the form idxSplit-numSplits (numSplits >= idxSplit >= 1), eg. --split 1-20"
        idx_split = int(split_parts[0])
        num_splits = int(split_parts[1])

    # Find all sequences
    print("")
    print(f"Looking for all {args.file_extension} files in {args.pathDB} with speakerLevel {args.recursionLevel}")
    seqNames, speakers = findAllSeqs(args.pathDB,
                                     speaker_level=args.recursionLevel,
                                     extension=args.file_extension,
                                     loadCache=True)

    if args.pathSeq:
        # Keep only the sequences whose base name (up to the first dot) is
        # listed in the file
        with open(args.pathSeq, 'r') as f:
            seqs = {x.strip() for x in f}
        seqNames = [s for s in seqNames
                    if s[1].split('/')[-1].split('.')[0] in seqs]

    print(f"Done! Found {len(seqNames)} files and {len(speakers)} speakers!")

    # Check if output file exists
    if not args.split:
        nameOutput = "quantized_outputs.txt"
    else:
        nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt"
    outputFile = os.path.join(args.pathOutput, nameOutput)
    assert not os.path.exists(outputFile), \
        f"Output file {outputFile} already exists !!!"

    # Get splits: equal-sized chunks, the last split takes the remainder
    if args.split:
        startIdx = len(seqNames) // num_splits * (idx_split-1)
        if idx_split == num_splits:
            endIdx = len(seqNames)
        else:
            endIdx = min(len(seqNames) // num_splits * idx_split, len(seqNames))
        seqNames = seqNames[startIdx:endIdx]
        print("")
        print(f"Quantizing split {idx_split} out of {num_splits} splits, with {len(seqNames)} files (idx in range({startIdx}, {endIdx})).")

    # Debug mode
    if args.debug:
        nsamples = 20
        print("")
        print(f"Debug mode activated, only load {nsamples} samples!")
        seqNames = seqNames[:nsamples]

    # Load Clustering args: "<checkpoint>_args.json" first, then
    # "checkpoint_args.json" in the checkpoint's directory
    assert args.pathCheckpoint[-3:] == ".pt"
    checkpoint_dir = os.path.dirname(args.pathCheckpoint)
    if os.path.exists(args.pathCheckpoint[:-3] + "_args.json"):
        pathConfig = args.pathCheckpoint[:-3] + "_args.json"
    elif os.path.exists(os.path.join(checkpoint_dir, "checkpoint_args.json")):
        pathConfig = os.path.join(checkpoint_dir, "checkpoint_args.json")
    else:
        # Explicit raise (same exception type as before) so the check is not
        # stripped when running with -O
        raise AssertionError(
            f"Args file not found in the directory {checkpoint_dir}")
    clustering_args = readArgs(pathConfig)
    print("")
    print(f"Clutering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}")
    print('-' * 50)

    # Load CluterModule
    clusterModule = loadClusterModule(args.pathCheckpoint, norm_vec_len=args.norm_vec_len)
    clusterModule.cuda()

    # Load FeatureMaker
    print("")
    print("Loading CPC FeatureMaker")
    if 'level_gru' in vars(clustering_args) and clustering_args.level_gru is not None:
        updateConfig = argparse.Namespace(nLevelsGRU=clustering_args.level_gru)
    else:
        updateConfig = None
    model = loadModel([clustering_args.pathCheckpoint], updateConfig=updateConfig)[0]
    ## If we don't apply batch implementation, we can set LSTM model to keep hidden units
    ## making the quality of the quantized units better
    if args.nobatch:
        model.gAR.keepHidden = True
    featureMaker = FeatureModule(model, clustering_args.encoder_layer)
    if clustering_args.dimReduction is not None:
        dimRed = loadDimReduction(clustering_args.dimReduction, clustering_args.centroidLimits)
        featureMaker = torch.nn.Sequential(featureMaker, dimRed)
    if not clustering_args.train_mode:
        featureMaker.eval()
    featureMaker.cuda()

    def feature_function(x):
        """Build the features of file ``x``, optionally length-normalized."""
        if args.nobatch is False:
            res0 = buildFeature_batch(featureMaker, x,
                                      seqNorm=False,
                                      strict=args.strict,
                                      maxSizeSeq=args.max_size_seq,
                                      batch_size=args.batch_size)
        else:
            res0 = buildFeature(featureMaker, x,
                                seqNorm=False,
                                strict=args.strict)
        if args.norm_vec_len:
            # [!] we actually used CPC_audio/scripts/quantize_audio.py for that in the end
            # Divide every feature vector by its Euclidean norm over dim 2
            res0Lengths = torch.sqrt((res0*res0).sum(2))
            res0 = res0 / res0Lengths.view(*(res0Lengths.shape), 1)
        return res0
    print("CPC FeatureMaker loaded!")

    # Quantization of files
    print("")
    print("Quantizing audio files...")
    seqQuantLines = []
    bar = progressbar.ProgressBar(maxval=len(seqNames))
    bar.start()
    start_time = time()
    for index, vals in enumerate(seqNames):
        bar.update(index)

        file_path = vals[1]
        file_path = os.path.join(args.pathDB, file_path)

        # Get features & quantizing
        cFeatures = feature_function(file_path).cuda()

        # Features may hold several groups of the centroid dimension
        nGroups = cFeatures.size(-1)//clusterModule.Ck.size(-1)

        cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1))

        if len(vals) > 2 and int(vals[-1]) > 9400000: # Librilight, to avoid OOM
            # Very long files are clustered on the CPU, then the module is
            # moved back to the GPU for the next file
            clusterModule = clusterModule.cpu()
            cFeatures = cFeatures.cpu()
            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
            clusterModule = clusterModule.cuda()
        else:
            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
        qFeatures = qFeatures[0].detach().cpu().numpy()

        # Transform to quantized line
        quantLine = ",".join(["-".join([str(i) for i in item]) for item in qFeatures.reshape(-1, nGroups)])
        seqQuantLines.append(quantLine)

    bar.finish()
    print(f"...done {len(seqQuantLines)} files in {time()-start_time} seconds.")

    # Saving outputs
    print("")
    print(f"Saving outputs to {outputFile}")
    outLines = []
    for vals, quantln in zip(seqNames, seqQuantLines):
        file_path = vals[1]
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        outLines.append("\t".join([file_name, quantln]))
    with open(outputFile, "w") as f:
        f.write("\n".join(outLines))
예제 #6
0
    nGPUs = torch.cuda.device_count()
    batchSize = args.batchSizeGPU * nGPUs
    trainLoader = dataset.getDataLoader(batchSize,
                                        "uniform",
                                        False,
                                        numWorkers=0)
    print(f"Length of dataLoader: {len(trainLoader)}")
    print("")

    if args.level_gru is None:
        updateConfig = None
    else:
        updateConfig = argparse.Namespace(nLevelsGRU=args.level_gru)

    model = loadModel([args.pathCheckpoint],
                      updateConfig=updateConfig,
                      load_nullspace=args.nullspace)[0]
    #model = loadModel([args.pathCheckpoint])[0]#, updateConfig=updateConfig)[0]

    featureMaker = FeatureModule(model, args.encoder_layer)
    print("Checkpoint loaded!")
    print("")

    if not args.train_mode:
        featureMaker.eval()
    featureMaker.cuda()

    # Check if dir exists
    if not os.path.exists(os.path.dirname(
            args.pathOutput)) and os.path.dirname(args.pathOutput):
        Path(os.path.dirname(args.pathOutput)).mkdir(parents=True,
예제 #7
0
def main(argv):
    """
    Train a phone or speaker separability criterion on top of a CPC model.

    The model is loaded with ``fl.loadModel`` and wrapped in DataParallel.
    A phone criterion (aligned or CTC) is used when ``args.pathPhone`` is
    given, a speaker criterion otherwise. The parsed arguments are dumped to
    ``<args.pathCheckpoint>/checkpoint_args.json`` before ``run`` is called.

    Args:
        argv: command-line arguments handed to ``parse_args``.
    """
    args = parse_args(argv)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    load_criterion = False

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    # Feature model
    model, hidden_gar, hidden_encoder = fl.loadModel(
        args.load, loadStateDict=not args.no_pretraining)
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(args.nGPU))

    # Size of the features handed to the criterion
    if args.get_encoded:
        dim_features = hidden_encoder
    else:
        dim_features = hidden_gar

    # Criterion: speakers unless phone labels are provided
    phone_labels = None
    if args.pathPhone is None:
        print("Running speaker separability")
        criterion = cr.SpeakerCriterion(dim_features, len(speakers))
    else:
        phone_labels, n_phones = parseSeqLabels(args.pathPhone)
        if args.CTC:
            print("Running phone separability with CTC loss")
            criterion = cr.CTCPhoneCriterion(dim_features, n_phones,
                                             args.get_encoded)
        else:
            print("Running phone separability with aligned phones")
            criterion = cr.PhoneCriterion(dim_features, n_phones,
                                          args.get_encoded)
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion, device_ids=range(args.nGPU))

    # Train / validation sequences (truncated in debug mode)
    train_seqs = filterSeqs(args.pathTrain, seqNames)
    val_seqs = filterSeqs(args.pathVal, seqNames)
    if args.debug:
        train_seqs, val_seqs = train_seqs[:1000], val_seqs[:100]

    db_train = AudioBatchData(args.pathDB, args.size_window, train_seqs,
                              phone_labels, len(speakers))
    db_val = AudioBatchData(args.pathDB, args.size_window, val_seqs,
                            phone_labels, len(speakers))

    batch_size = args.batchSizeGPU * args.nGPU
    train_loader = db_train.getDataLoader(batch_size, "uniform", True,
                                          numWorkers=0)
    val_loader = db_val.getDataLoader(batch_size, 'sequential', False,
                                      numWorkers=0)

    # Parameters to optimize: always the criterion, plus the model when it
    # is not frozen
    trainable = list(criterion.parameters())
    model.optimize = False
    model.eval()
    if not args.unfrozen:
        print("Working with frozen features")
        for weight in model.parameters():
            weight.requires_grad = False
    else:
        print("Working in full fine-tune mode")
        trainable += list(model.parameters())
        model.optimize = True

    optimizer = torch.optim.Adam(trainable,
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Checkpoint directory; args.pathCheckpoint ends up as the string prefix
    # "<dir>/checkpoint"
    checkpoint_dir = Path(args.pathCheckpoint)
    checkpoint_dir.mkdir(exist_ok=True)
    args.pathCheckpoint = str(checkpoint_dir / "checkpoint")

    with open(f"{args.pathCheckpoint}_args.json", 'w') as args_file:
        json.dump(vars(args), args_file, indent=2)

    run(model, criterion, train_loader, val_loader, optimizer, logs,
        args.n_epoch, args.pathCheckpoint)
예제 #8
0
        checkpoint_url = 'https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints/60k_epoch4-d0f474de.pt'
        checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url,
                                                        progress=False,
                                                        map_location="cuda:0")
        loadArgs(locArgs, argparse.Namespace(**checkpoint["config"]))
        encoderNet = getEncoder(locArgs)
        arNet = getAR(locArgs)
        model = CPCModel(encoderNet, arNet)
        if not args.no_pretraining:
            model.load_state_dict(checkpoint["weights"], strict=False)
        feature_maker = model
        hiddenGar = locArgs.hiddenGar
        print(feature_maker, hiddenGar)
        print()
    else:
        feature_maker, hiddenGar, _ = loadModel(
            [args.pathCheckpoint], loadStateDict=not args.no_pretraining)
    feature_maker.cuda()
    feature_maker = torch.nn.DataParallel(feature_maker)

    phone_criterion = CTCphone_criterion(hiddenGar,
                                         nPhones,
                                         args.LSTM,
                                         seqNorm=args.seqNorm,
                                         dropout=args.dropout,
                                         reduction=args.loss_reduction)
    phone_criterion.cuda()
    phone_criterion = torch.nn.DataParallel(phone_criterion)

    print(f"Loading the validation dataset at {args.pathDB}")
    datasetVal = SingleSequenceDataset(args.pathDB,
                                       seqVal,
    args = parser.parse_args()

    if not os.path.isdir(args.pathOut):
        os.mkdir(args.pathOut)

    with open(os.path.join(os.path.dirname(args.pathOut),
                           f"{os.path.basename(args.pathOut)}.json"), 'w') \
            as file:
        json.dump(vars(args), file, indent=2)

    outData = [
        x[1] for x in findAllSeqs(
            args.pathDB, extension=args.extension, loadCache=False)[0]
    ]

    featureMaker = loadModel([args.pathCheckpoint])[0]
    stepSize = featureMaker.gEncoder.DOWNSAMPLING / 16000
    print(f"stepSize : {stepSize}")
    featureMaker = FeatureModule(featureMaker, args.getEncoded)
    featureMaker.collapse = False

    if args.addCriterion:
        criterion, nPhones = loadSupervisedCriterion(args.pathCheckpoint)
        featureMaker = ModelPhoneCombined(featureMaker, criterion, nPhones,
                                          args.oneHot)
    if device == "cuda":
        featureMaker = featureMaker.cuda(device=0)

    if not args.train_mode:
        featureMaker.eval()
예제 #10
0
 def __init__(self, model_path, intermediate_idx=0):
     """
     Load the model stored at ``model_path`` and keep the hidden state of
     its ``gAR`` sub-module between calls.

     ``intermediate_idx`` is forwarded unchanged to ``loadModel``.
     """
     super().__init__()
     # loadModel takes a list of checkpoint paths; its first result is the model
     loaded = loadModel([model_path], intermediate_idx=intermediate_idx)
     self.model = loaded[0]
     self.model.gAR.keepHidden = True
예제 #11
0
def main(args):
    """Parse the arguments, build data/model/criterion/optimizer and run the requested mode.

    Depending on the parsed flags this function:
      * calls ``run`` when neither ``onlyCapture`` nor ``only_classif_metric``
        is set,
      * calls ``onlyCapture`` when ``args.onlyCapture`` is set,
      * calls the linear-separability training closure when
        ``args.only_classif_metric`` is set.

    Args:
        args: raw arguments, handed to ``parseArgs``.

    Raises:
        ValueError: if ``only_classif_metric`` is requested without
            ``supervised_classif_metric`` (the linsep closure is only built
            in the latter case).
    """

    # import ptvsd
    # ptvsd.enable_attach(('0.0.0.0', 7309))
    # print("Attach debugger now")
    # ptvsd.wait_for_attach()

    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    # pathCheckpoint is treated as optional everywhere below, so the directory
    # creation and the early args dump must be guarded the same way.
    if args.pathCheckpoint is not None:
        os.makedirs(args.pathCheckpoint, exist_ok=True)
        if not args.onlyCapture and not args.only_classif_metric:
            # NOTE(review): this writes checkpoint_args.json before
            # fl.getCheckpointData runs on the same directory - confirm it
            # does not clobber arguments needed when resuming.
            args_dump_path = os.path.join(args.pathCheckpoint,
                                          'checkpoint_args.json')
            with open(args_dump_path, 'wt') as file:
                json.dump(vars(args), file)
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args, locArgs,
                        forbiddenAttr={"nGPU", "pathCheckpoint",
                                       "debug", "restart", "world_size",
                                       "n_nodes", "node_id", "n_gpu_per_node",
                                       "max_size_loaded"})
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    if not args.onlyCapture or args.only_classif_metric:
        print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
        # Datasets
        if args.pathTrain is not None:
            seqTrain = filterSeqs(args.pathTrain, seqNames)
        else:
            seqTrain = seqNames

        if args.pathVal is None:
            # No explicit validation list: hold out the last 1% of the
            # shuffled training sequences.
            random.shuffle(seqTrain)
            sizeTrain = int(0.99 * len(seqTrain))
            seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
            print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
        else:
            seqVal = filterSeqs(args.pathVal, seqNames)

    if args.pathCaptureDS is not None:
        assert args.pathCaptureSave is not None
        whatToSave = []
        if args.captureEverything:
            whatToSave = ['conv_repr', 'ctx_repr', 'speaker_align', 'pred']
            if args.path_phone_data:
                whatToSave.append('phone_align')
            if args.CPCCTC:
                whatToSave.append('cpcctc_align')
                whatToSave.append('cpcctc_log_scores')
        else:
            # Pair every capture flag with the name of the item it enables.
            for argVal, name in zip([args.captureConvRepr,
                                     args.captureCtxRepr,
                                     args.captureSpeakerAlign,
                                     args.capturePhoneAlign,
                                     args.capturePred,
                                     args.captureCPCCTCalign,
                                     args.captureCPCCTClogScores],
                                    ['conv_repr', 'ctx_repr', 'speaker_align', 'phone_align', 'pred', 'cpcctc_align', 'cpcctc_log_scores']):
                if argVal:
                    whatToSave.append(name)
        ###assert len(whatToSave) > 0
        captureOptions = {
            'path': args.pathCaptureSave,
            'eachEpochs': args.captureEachEpochs,
            'what': whatToSave
        }
        seqCapture = filterSeqs(args.pathCaptureDS, seqNames,
                                percentage=args.captureDSfreq, totalNum=args.captureDStotNr)
        print(f'Capture files: {len(seqCapture)}')
    else:
        seqCapture = None
        captureOptions = None

    if not args.onlyCapture:
        if args.debug:
            seqTrain = seqTrain[-1000:]
            seqVal = seqVal[-100:]

        phoneLabels, nPhones = None, None
        if args.supervised and args.pathPhone is not None:
            print("Loading the phone labels at " + args.pathPhone)
            phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
            print(f"{nPhones} phones found")

        print("")
        print(f'Loading audio data at {args.pathDB}')
        print("Loading the training dataset")
        trainDataset = AudioBatchData(args.pathDB,
                                      args.sizeWindow,
                                      seqTrain,
                                      phoneLabels,
                                      len(speakers),
                                      nProcessLoader=args.n_process_loader,
                                      MAX_SIZE_LOADED=args.max_size_loaded)
        print("Training dataset loaded")
        print("")

        print("Loading the validation dataset")
        valDataset = AudioBatchData(args.pathDB,
                                    args.sizeWindow,
                                    seqVal,
                                    phoneLabels,
                                    len(speakers),
                                    nProcessLoader=args.n_process_loader)
        print("Validation dataset loaded")
        print("")
    else:
        phoneLabels, nPhones = None, None
        trainDataset = None
        valDataset = None

    if seqCapture is not None:

        if args.path_phone_data:
            print("Loading the phone labels at " + args.path_phone_data)
            phoneLabelsForCapture, _ = parseSeqLabels(args.path_phone_data)
        else:
            assert not args.capturePhoneAlign
            phoneLabelsForCapture = None

        print("Loading the capture dataset")
        captureDataset = AudioBatchData(args.pathDB,
                                        args.sizeWindow,
                                        seqCapture,
                                        phoneLabelsForCapture,
                                        len(speakers),
                                        nProcessLoader=args.n_process_loader)
        print("Capture dataset loaded")
        print("")

        if args.captureSetStats:
            captureSetStatsCollector = statutil.constructStatCollectorFromSpecs(args.captureSetStats)
        else:
            captureSetStatsCollector = None
    else:
        captureDataset = None
        captureSetStatsCollector = None

    if args.load is not None:
        if args.gru_level is not None and args.gru_level > 0:
            updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
        else:
            updateConfig = None

        # loadBestNotLast = args.onlyCapture or args.only_classif_metric
        # could use this option for loading best state when not running actual training
        # but relying on CPC internal acc isn't very reliable
        # [!] caution - because of how they capture checkpoints,
        #     they capture "best in this part of training" as "best" (apart from capturing current state)
        #     so if best is in epoch 100 and training is paused and resumed from checkpoint
        #     in epoch 150, checkpoint from epoch 200 has "best from epoch 150" saved as globally best
        #     (but this is internal-CPC-score best anyway, which is quite vague)
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load, load_nullspace=args.nullspace, updateConfig=updateConfig)
        CPChiddenGar, CPChiddenEncoder = args.hiddenGar, args.hiddenEncoder

        if args.gru_level is not None and args.gru_level > 0:
            # Keep hidden units at LSTM layers on sequential batches
            if args.nullspace:
                cpcModel.cpc.gAR.keepHidden = True
            else:
                cpcModel.gAR.keepHidden = True

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

        CPChiddenGar, CPChiddenEncoder = cpcModel.gAR.getDimOutput(), cpcModel.gEncoder.getDimOutput()

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # The nullspace wrapper keeps the CPC model under its ``cpc`` attribute.
    downsampling = cpcModel.cpc.gEncoder.DOWNSAMPLING if isinstance(cpcModel, model.CPCModelNullspace) else cpcModel.gEncoder.DOWNSAMPLING
    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0], downsampling,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, downsampling,
                                    len(speakers), nPhones)

    # The checkpoint is read from disk once and reused for both the criterion
    # and (further down) the optimizer state.
    state_dict = None
    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params, lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer and not args.onlyCapture and not args.only_classif_metric:
        print("Loading optimizer " + args.load[0])
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None and not args.onlyCapture and not args.only_classif_metric:
        # The directory itself was already created right after parsing.
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")
        with open(args.pathCheckpoint + "_args.json", 'w') as file:
            json.dump(vars(args), file, indent=2)

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                           lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                                                               n_epoch, epoch),
                                                           last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Replay one scheduler step per epoch already present in the logs.
        print(f'Redoing {len(logs["epoch"])} scheduler steps')
        for _ in logs["epoch"]:
            scheduler.step()

    print("cpcModel", cpcModel)
    print("cpcCriterion", cpcCriterion)

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    if args.supervised_classif_metric:

        linsep_batch_size = args.linsepBatchSizeGPU * args.nGPU

        dim_features = CPChiddenEncoder if args.phone_get_encoded else CPChiddenGar
        dim_ctx_features = CPChiddenGar  # for speakers using CNN encodings is not supported; could add but not very useful perhaps

        phoneLabelsData = None
        if args.path_phone_data:
            phoneLabelsData, nPhonesInData = parseSeqLabels(args.path_phone_data)

            if not args.CTCphones:
                print("Running phone separability with aligned phones")
            else:
                print("Running phone separability with CTC loss")

            def constructPhoneCriterionAndOptimizer():
                """Build a fresh phone classifier criterion and its Adam optimizer."""
                if not args.CTCphones:
                    phone_criterion = cr.PhoneCriterion(dim_features,
                                                        nPhonesInData, args.phone_get_encoded,
                                                        nLayers=args.linsep_net_layers)
                else:
                    phone_criterion = cr.CTCPhoneCriterion(dim_features,
                                                           nPhonesInData, args.phone_get_encoded,
                                                           nLayers=args.linsep_net_layers)
                phone_criterion.cuda()
                phone_criterion = torch.nn.DataParallel(phone_criterion, device_ids=range(args.nGPU))

                # Optimizer
                phone_g_params = list(phone_criterion.parameters())

                phone_optimizer = torch.optim.Adam(phone_g_params, lr=args.linsep_lr,
                                                   betas=(args.linsep_beta1, args.linsep_beta2),
                                                   eps=args.linsep_epsilon)

                return phone_criterion, phone_optimizer

        if args.speaker_sep:
            print("Running speaker separability")

            def constructSpeakerCriterionAndOptimizer():
                """Build a fresh speaker classifier criterion and its Adam optimizer."""
                speaker_criterion = cr.SpeakerCriterion(dim_ctx_features, len(speakers),
                                                        nLayers=args.linsep_net_layers)
                speaker_criterion.cuda()
                speaker_criterion = torch.nn.DataParallel(speaker_criterion, device_ids=range(args.nGPU))

                speaker_g_params = list(speaker_criterion.parameters())

                speaker_optimizer = torch.optim.Adam(speaker_g_params, lr=args.linsep_lr,
                                                     betas=(args.linsep_beta1, args.linsep_beta2),
                                                     eps=args.linsep_epsilon)

                return speaker_criterion, speaker_optimizer

        linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
                                         phoneLabelsData, len(speakers))
        linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
                                       phoneLabelsData, len(speakers))

        linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size, "uniform", True,
                                                            numWorkers=0)

        linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size, 'sequential', False,
                                                        numWorkers=0)

        def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
            """Train phone and/or speaker classifiers on top of ``cpcMdl``.

            Log (and optionally checkpoint) directories are created under a
            sub-directory named after ``numOfEpoch``. Returns the merged logs
            of both tasks, with keys prefixed by ``phone_`` / ``speaker_``.
            """
            log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
            log_path_phoneme = os.path.join(log_path_for_epoch, "phoneme/")
            log_path_speaker = os.path.join(log_path_for_epoch, "speaker/")
            os.makedirs(log_path_phoneme, exist_ok=True)
            os.makedirs(log_path_speaker, exist_ok=True)

            # Always bind the checkpoint paths so the calls below do not fail
            # with UnboundLocalError when no checkpoint directory is given.
            # NOTE(review): confirm linsep.trainLinsepClassification accepts
            # None as its checkpoint path.
            checkpoint_path_phoneme = None
            checkpoint_path_speaker = None
            if args.linsep_checkpoint_dir:
                checkpoint_path_for_epoch = os.path.join(args.linsep_checkpoint_dir, str(numOfEpoch))
                checkpoint_path_phoneme = os.path.join(checkpoint_path_for_epoch, "phoneme/")
                checkpoint_path_speaker = os.path.join(checkpoint_path_for_epoch, "speaker/")
                os.makedirs(checkpoint_path_phoneme, exist_ok=True)
                os.makedirs(checkpoint_path_speaker, exist_ok=True)

            locLogsPhone = {}
            locLogsSpeaker = {}
            if args.path_phone_data:
                phone_criterion, phone_optimizer = constructPhoneCriterionAndOptimizer()
                locLogsPhone = linsep.trainLinsepClassification(
                    cpcMdl,
                    phone_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    phone_optimizer,
                    log_path_phoneme,
                    args.linsep_task_logging_step,
                    checkpoint_path_phoneme,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'phone')
                del phone_criterion
                del phone_optimizer
            if args.speaker_sep:
                speaker_criterion, speaker_optimizer = constructSpeakerCriterionAndOptimizer()
                locLogsSpeaker = linsep.trainLinsepClassification(
                    cpcMdl,
                    speaker_criterion,  # combined with classification model before
                    linsep_train_loader,
                    linsep_val_loader,
                    speaker_optimizer,
                    log_path_speaker,
                    args.linsep_task_logging_step,
                    checkpoint_path_speaker,
                    args.linsep_n_epoch,
                    cpcStateEpoch,
                    'speaker')
                del speaker_criterion
                del speaker_optimizer

            locLogsPhone = {"phone_" + k: v for k, v in locLogsPhone.items()}
            locLogsSpeaker = {"speaker_" + k: v for k, v in locLogsSpeaker.items()}
            return {**locLogsPhone, **locLogsSpeaker}

        linsepClassificationTaskConfig = (args.linsep_classif_each_epochs,
                                          runLinsepClassificationTraining)

    else:
        linsepClassificationTaskConfig = (None, None)

    if not args.onlyCapture and not args.only_classif_metric:
        run(trainDataset,
            valDataset,
            (captureDataset, captureOptions, captureSetStatsCollector),
            linsepClassificationTaskConfig,
            batchSize,
            args.samplingType,
            cpcModel,
            cpcCriterion,
            args.nEpoch,
            args.pathCheckpoint,
            optimizer,
            scheduler,
            logs)
    if args.onlyCapture:
        # caution [!] - will capture for last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        onlyCapture(
            (captureDataset, captureOptions, captureSetStatsCollector),
            batchSize,
            cpcModel,
            cpcCriterion,
            logs)
    if args.only_classif_metric:
        # caution [!] - will use last checkpoint (last saved state) if checkpoint directory given
        #               to use specific checkpoint provide full checkpoint file path
        #               will use "last state" and not "best in internal CPC accuracy" anyway
        if not args.supervised_classif_metric:
            # runLinsepClassificationTraining only exists when
            # supervised_classif_metric is set; fail with a clear message
            # instead of a NameError.
            raise ValueError(
                "only_classif_metric requires supervised_classif_metric to be set")
        trainedEpoch = len(logs["epoch"]) - 1
        runLinsepClassificationTraining(trainedEpoch, cpcModel, trainedEpoch)
예제 #12
0
def main():
    """Dump the outputs of selected model layers for every audio file of a dataset tree.

    A CPC model (``args.model == "cpc"``) or a fairseq checkpoint is loaded,
    forward hooks are registered on the layers of interest, and every file
    under ``<path_data>/phonetic/<dataset>/`` ending with
    ``args.file_extension`` is pushed through the model. The captured
    activations are written with ``np.savetxt`` to
    ``<path_output_dir>/<layer>/phonetic/<dataset>/<file stem>.txt``.
    """
    # Parse and print args
    args = parse_args()
    logger.info(args)

    # Load the model
    print("")
    print(f"Loading model from {args.path_checkpoint}")
    if args.model == "cpc":
        # The CPC package is imported from a user-supplied location.
        sys.path.append(os.path.abspath(args.path_cpc))
        from cpc.feature_loader import loadModel, FeatureModule
        model = loadModel([args.path_checkpoint])[0]
    else:
        from fairseq import checkpoint_utils

        def loadCheckpoint(path_checkpoint, path_data):
            """
            Load the first model of the fairseq ensemble stored at path_checkpoint.
            """
            # Set up the args Namespace
            model_args = argparse.Namespace(
                task="language_modeling",
                output_dictionary_size=-1,
                data=path_data,
                path=path_checkpoint
                )

            # Load model
            models, _model_args = checkpoint_utils.load_model_ensemble([model_args.path])
            return models[0]

        model = loadCheckpoint(args.path_checkpoint, args.path_data)

    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"

    # Register the hooks: each hook stores the latest output of its layer
    # (as a numpy array) under the layer's name.
    layer_outputs = {}

    def get_layer_output(name):
        def hook(model, input, output):
            # isinstance (rather than an exact type check) so that tuple/dict
            # subclasses are unpacked too instead of being treated as tensors.
            if isinstance(output, tuple):
                layer_outputs[name] = output[0].detach().squeeze(1).cpu().numpy()
            elif isinstance(output, dict):
                layer_outputs[name] = output["x"].detach().squeeze(0).cpu().numpy()
            else:
                layer_outputs[name] = output.detach().squeeze(0).cpu().numpy()
        return hook

    layer_names = []
    if args.model == "cpc":
        # The single CPC layer is named after the checkpoint's directory.
        layer_name = os.path.basename(os.path.dirname(args.path_checkpoint))
        layer_names.append(layer_name)
        model.gAR.register_forward_hook(get_layer_output(layer_name))
    else:
        for i, layer in enumerate(model.encoder.layers):
            layer_name = "layer_{}".format(i)
            layer_names.append(layer_name)
            layer.register_forward_hook(get_layer_output(layer_name))
        layer_name = "last"
        layer_names.append(layer_name)
        model.register_forward_hook(get_layer_output(layer_name))

    model = model.eval().to(device)
    print("Model loaded!")
    print(model)

    # Extract values from chosen layers and save them to files
    phonetic = "phonetic"
    datasets_path = os.path.join(args.path_data, phonetic)
    datasets = os.listdir(datasets_path)
    print(datasets)

    with torch.no_grad():
        for dataset in datasets:
            print("> {}".format(dataset))
            dataset_path = os.path.join(datasets_path, dataset)
            files = [f for f in os.listdir(dataset_path) if f.endswith(args.file_extension)]
            for i, f in enumerate(files):
                print("Progress {:2.1%}".format(i / len(files)), end="\r")
                input_f = os.path.join(dataset_path, f)
                x, _ = sf.read(input_f)
                x = torch.tensor(x).float().reshape(1, -1).to(device)

                # The forward pass is run only for its side effect: the hooks
                # fill layer_outputs. The return value itself is not needed.
                if args.model == "cpc":
                    encodedData = model.gEncoder(x.unsqueeze(1)).permute(0, 2, 1)
                    model.gAR(encodedData)
                else:
                    model(x, features_only=True)

                for layer_name, value in layer_outputs.items():
                    output_dir = os.path.join(args.path_output_dir, layer_name, phonetic, dataset)
                    Path(output_dir).mkdir(parents=True, exist_ok=True)
                    out_f = os.path.join(output_dir, os.path.splitext(f)[0] + ".txt")
                    np.savetxt(out_f, value)
예제 #13
0
def main(args):
    """Parse the arguments, build datasets/model/criterion/optimizer and call ``run``.

    When a checkpoint is found under ``args.pathCheckpoint`` (and
    ``args.restart`` is not set) the stored arguments, logs, criterion and
    optimizer state are restored before training continues.

    Args:
        args: raw arguments, handed to ``parseArgs``.
    """
    args = parseArgs(args)

    utils.set_seed(args.random_seed)
    logs = {"epoch": [], "iter": [], "saveStep": args.save_step}
    loadOptimizer = False
    if args.pathCheckpoint is not None and not args.restart:
        cdata = fl.getCheckpointData(args.pathCheckpoint)
        if cdata is not None:
            data, logs, locArgs = cdata
            print(f"Checkpoint detected at {data}")
            fl.loadArgs(args,
                        locArgs,
                        forbiddenAttr={
                            "nGPU", "pathCheckpoint", "debug", "restart",
                            "world_size", "n_nodes", "node_id",
                            "n_gpu_per_node", "max_size_loaded"
                        })
            args.load, loadOptimizer = [data], True
            args.loadCriterion = True

    logs["logging_step"] = args.logging_step

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    seqNames, speakers = findAllSeqs(args.pathDB,
                                     extension=args.file_extension,
                                     loadCache=not args.ignore_cache)

    print(f'Found files: {len(seqNames)} seqs, {len(speakers)} speakers')
    # Datasets
    if args.pathTrain is not None:
        seqTrain = filterSeqs(args.pathTrain, seqNames)
    else:
        seqTrain = seqNames

    if args.pathVal is None:
        # No explicit validation list: hold out the last 1% of the shuffled
        # training sequences.
        random.shuffle(seqTrain)
        sizeTrain = int(0.99 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
        print(f'Found files: {len(seqTrain)} train, {len(seqVal)} val')
    else:
        seqVal = filterSeqs(args.pathVal, seqNames)

    if args.debug:
        seqTrain = seqTrain[-1000:]
        seqVal = seqVal[-100:]

    phoneLabels, nPhones = None, None
    if args.supervised and args.pathPhone is not None:
        print("Loading the phone labels at " + args.pathPhone)
        phoneLabels, nPhones = parseSeqLabels(args.pathPhone)
        print(f"{nPhones} phones found")

    print("")
    print(f'Loading audio data at {args.pathDB}')
    print("Loading the training dataset")
    trainDataset = AudioBatchData(args.pathDB,
                                  args.sizeWindow,
                                  seqTrain,
                                  phoneLabels,
                                  len(speakers),
                                  nProcessLoader=args.n_process_loader,
                                  MAX_SIZE_LOADED=args.max_size_loaded)
    print("Training dataset loaded")
    print("")

    print("Loading the validation dataset")
    valDataset = AudioBatchData(args.pathDB,
                                args.sizeWindow,
                                seqVal,
                                phoneLabels,
                                len(speakers),
                                nProcessLoader=args.n_process_loader)
    print("Validation dataset loaded")
    print("")

    if args.load is not None:
        cpcModel, args.hiddenGar, args.hiddenEncoder = \
            fl.loadModel(args.load)

    else:
        # Encoder network
        encoderNet = fl.getEncoder(args)
        # AR Network
        arNet = fl.getAR(args)

        cpcModel = model.CPCModel(encoderNet, arNet)

    batchSize = args.nGPU * args.batchSizeGPU
    cpcModel.supervised = args.supervised

    # Training criterion
    if args.load is not None and args.loadCriterion:
        cpcCriterion = loadCriterion(args.load[0],
                                     cpcModel.gEncoder.DOWNSAMPLING,
                                     len(speakers), nPhones)
    else:
        cpcCriterion = getCriterion(args, cpcModel.gEncoder.DOWNSAMPLING,
                                    len(speakers), nPhones)

    # The checkpoint is read from disk once and reused for both the criterion
    # and (further down) the optimizer state.
    state_dict = None
    if loadOptimizer:
        state_dict = torch.load(args.load[0], 'cpu')
        cpcCriterion.load_state_dict(state_dict["cpcCriterion"])

    cpcCriterion.cuda()
    cpcModel.cuda()

    # Optimizer
    g_params = list(cpcCriterion.parameters()) + list(cpcModel.parameters())

    lr = args.learningRate
    optimizer = torch.optim.Adam(g_params,
                                 lr=lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    if loadOptimizer:
        print("Loading optimizer " + args.load[0])
        if "optimizer" in state_dict:
            optimizer.load_state_dict(state_dict["optimizer"])

    # Checkpoint
    if args.pathCheckpoint is not None:
        # makedirs also creates missing parent directories and tolerates an
        # already existing directory.
        os.makedirs(args.pathCheckpoint, exist_ok=True)
        args.pathCheckpoint = os.path.join(args.pathCheckpoint, "checkpoint")

    scheduler = None
    if args.schedulerStep > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    args.schedulerStep,
                                                    gamma=0.5)
    if args.schedulerRamp is not None:
        n_epoch = args.schedulerRamp
        print(f"Ramp activated. n_e = {n_epoch}")
        scheduler_ramp = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: utils.ramp_scheduling_function(
                n_epoch, epoch),
            last_epoch=-1)
        if scheduler is None:
            scheduler = scheduler_ramp
        else:
            scheduler = utils.SchedulerCombiner([scheduler_ramp, scheduler],
                                                [0, args.schedulerRamp])
    if scheduler is not None:
        # Replay one scheduler step per epoch already present in the logs.
        for _ in logs["epoch"]:
            scheduler.step()

    cpcModel = torch.nn.DataParallel(cpcModel,
                                     device_ids=range(args.nGPU)).cuda()
    cpcCriterion = torch.nn.DataParallel(cpcCriterion,
                                         device_ids=range(args.nGPU)).cuda()

    run(trainDataset, valDataset, batchSize, args.samplingType, cpcModel,
        cpcCriterion, args.nEpoch, args.pathCheckpoint, optimizer, scheduler,
        logs)