Code example #1
def write_cati_csv():
    import pandas as pd
    import numpy as np
    from utils_file import get_parent_path, gfile, gdir

    data_path = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/CATI_datasets/'
    fcsv = data_path + 'all_cati.csv'
    res = pd.read_csv(fcsv)

    # keep only series whose global qualitative QC score is above 3
    ser_dir = res.cenir_QC_path[res.globalQualitative > 3].values
    dcat = gdir(ser_dir, 'cat12')
    fT1 = gfile(dcat, '^s.*nii')
    fms = gfile(dcat, '^ms.*nii')
    fs_brain = gfile(dcat, '^brain_s.*nii')
    # return fT1, fms, fs_brain

    # random split: first 100 permuted subjects for training, the rest for validation
    ind_perm = np.random.permutation(range(0, len(fT1)))
    itrain = ind_perm[0:100]
    ival = ind_perm[100:]

    dd = pd.DataFrame({'filename': fT1})
    dd.to_csv(data_path + 'cati_cenir_QC4_all_T1.csv', index=False)
    dd.loc[ival, :].to_csv(data_path + 'cati_cenir_QC4_val_T1.csv',
                           index=False)
    dd.loc[itrain, :].to_csv(data_path + 'cati_cenir_QC4_train_T1.csv',
                             index=False)

    dd = pd.DataFrame({'filename': fms})
    dd.to_csv(data_path + 'cati_cenir_QC4_all_ms.csv', index=False)
    dd.loc[ival, :].to_csv(data_path + 'cati_cenir_QC4_val_ms.csv',
                           index=False)
    dd.loc[itrain, :].to_csv(data_path + 'cati_cenir_QC4_train_ms.csv',
                             index=False)

    dd = pd.DataFrame({'filename': fs_brain})
    dd.to_csv(data_path + 'cati_cenir_QC4_all_brain.csv', index=False)
    dd.loc[ival, :].to_csv(data_path + 'cati_cenir_QC4_val_brain.csv',
                           index=False)
    dd.loc[itrain, :].to_csv(data_path + 'cati_cenir_QC4_train_brain.csv',
                             index=False)

    dd = pd.DataFrame({'filename': fT1})
    dd.to_csv(data_path + 'cati_cenir_all_T1.csv', index=False)
    dd = pd.DataFrame({'filename': fms})
    dd.to_csv(data_path + 'cati_cenir_all_ms.csv', index=False)
    dd = pd.DataFrame({'filename': fs_brain})
    dd.to_csv(data_path + 'cati_cenir_all_brain.csv', index=False)

    # add a brain-mask column to each cati_cenir CSV
    allcsv = gfile('/home/romain.valabregue/datal/QCcnn/CATI_datasets',
                   'cati_cenir.*csv')

    for onecsv in allcsv:
        res = pd.read_csv(onecsv)
        resout = onecsv[:-4] + '_mask.csv'
        fmask = []
        for ft1 in res.filename:
            d = get_parent_path(ft1)[0]
            fmask += gfile(d, '^mask', opts={"items": 1})
        res['brain_mask'] = fmask
        res.to_csv(resout, index=False)
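
Not part of the original function: a minimal sketch reading back one generated train/val pair to verify the split sizes.

import pandas as pd

data_path = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/CATI_datasets/'
train = pd.read_csv(data_path + 'cati_cenir_QC4_train_T1.csv')
val = pd.read_csv(data_path + 'cati_cenir_QC4_val_T1.csv')
print('{} training / {} validation T1 volumes'.format(len(train), len(val)))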
Code example #2
import numpy as np
import nibabel as nib
from torchio import ImagesDataset
from torchio.transforms import RandomMotion
from utils_file import gfile, gdir

# NOTE: suj is only defined further down in this scratch file
dataset = ImagesDataset(suj)
transforms = [
    #ZNormalization(verbose=verbose),
    RandomMotion(proportion_to_augment=1, seed=1, verbose=True),
]
sample = dataset[0]
for i, transform in enumerate(transforms):
    transformed = transform(sample)
    name = transform.__class__.__name__
    path = f'/tmp/{i}_{name}_abs.nii.gz'
    dataset.save_sample(transformed, dict(T1=path))

# histogram normalization
from torchio.transforms.preprocessing.histogram_standardization import train, normalize

suj = gdir('/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/HCPdata','^suj')
allfiles = gfile(suj,'^T1w_1mm.nii.gz')
allfiles_mask = gfile(suj,'^brain_T1w_1mm.nii.gz')
testf = allfiles[0:300]
outname = '/data/romain/data_exemple/landmarks_hcp300_res100.npy'
#outname = '/data/romain/data_exemple/landmarks_hcp300_res100_cutof01.npy'

landmark = train(testf, output_path=outname, mask_path=allfiles_mask, cutoff=(0, 1))

nii = nib.load(testf[0]).get_fdata(dtype=np.float32)
niim = normalize(nii, landmark)

perc_database = np.load(outname)

# average the stored landmarks along axis 1
mm = np.mean(perc_database, axis=1)
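
To eyeball the trained landmarks, a minimal matplotlib sketch (assuming, as the axis=1 mean above suggests, that perc_database holds one row per percentile landmark and one column per subject):

import matplotlib.pyplot as plt

# Plot the per-landmark mean computed above.
plt.plot(mm, 'o-')
plt.xlabel('landmark index')
plt.ylabel('mean intensity across training subjects')
plt.show()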
Code example #3
import pandas as pd
from script.create_jobs import create_jobs
from utils_file import get_parent_path, gfile, gdir
from utils import get_ep_iter_from_res_name

# parameters

root_dir = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/NN_regres_motion/'
prefix = "/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/job/job_eval_again/"

# successive scratch-style overrides: only the last `model` assignment below is used
#model = root_dir + 'RegMotNew_mvt_train_hcp400_ms_B4_nw0_Size182_ConvN_C16_256_Lin40_50_D0_DC0.5_BN_Loss_L1_lr0.0001/'
model = root_dir + 'RegMotNew_ela1_train200_hcp400_ms_B4_nw0_Size182_ConvN_C16_256_Lin40_50_D0_BN_Loss_L1_lr0.0001/'
#model = root_dir +'RegMotNew_mvt_train_hcp400_ms_B4_nw0_Size182_ConvN_C16_256_Lin40_50_D0_BN_Loss_L1_lr0.0001'
#model = gdir(root_dir,'New_ela1_train_hcp400_T1')
#model = gdir(root_dir,'New_resc.*hcp.*T1.*DC0.1')
model = gdir(root_dir, 'rescal.*cati')
model = gdir(root_dir, 'RegMotNew_ela1.*cati_ms')
model = gdir(root_dir, 'Mask_rescale_ela1.*hcp.*T1')
model = gdir(root_dir, 'Mask.*ela1.*hcp.*ms')
model = gdir(root_dir, 'Reg.*D0')

#saved_models = gfile(model, '_ep(30|20|10)_.*000.pt$')
saved_models = gfile(model, '_ep9_.*000.pt$')
saved_models = saved_models[0:10]
#saved_models = gfile(model, '_ep(24|46|48)_.*000.pt$');
#saved_models = gfile(model, '_ep(10|11)_.*000.pt$');
#saved_models = gfile(model, '_ep(10|9)_.*000.pt$');
saved_models = gfile(model, '_ep([789]|10)_.*4500.pt$')
saved_models = []
for mm in model:
    ss_models = gfile(mm, '_ep.*pt$')
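
The listing truncates inside this final loop. The commented-out block in code example #4 below suggests how the collected checkpoints were then filtered; a rough sketch of such a selection, assuming get_ep_iter_from_res_name(names, nb_it) returns the name list sorted by epoch/iteration together with the epoch and iteration arrays:

# Hedged sketch, mirroring the commented-out block in code example #4:
# keep only the most recent checkpoints of each model directory.
nb_it = 8000  # iterations per epoch, value taken from that commented block
saved_models = []
for mm in model:
    ss_models = gfile(mm, '_ep.*pt$')
    fres_sorted, b, c = get_ep_iter_from_res_name(ss_models, nb_it)
    saved_models += list(fres_sorted[-8:])  # the 8 latest checkpoints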
Code example #4
    transforms = tc

doit.set_data_loader(train_csv_file=train_csv_file,
                     val_csv_file=val_csv_file,
                     transforms=transforms,
                     batch_size=batch_size,
                     num_workers=num_workers,
                     save_to_dir=load_from_dir[0],
                     replicate_suj=nb_replicate,
                     collate_fn=lambda x: x)

if do_eval:

    root_dir = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/NN_regres_random_noise/'
    model = gdir(root_dir, 'Reg.*D0_DC')

    # saved_models = []
    # for mm in model:
    #     ss_models = gfile(mm, '_ep.*pt$');
    #     nb_it=8000
    #     fresV_sorted, b, c = get_ep_iter_from_res_name(ss_models, nb_it)
    #     nb_iter = b * nb_it + c
    #     ii = np.where(nb_iter > 200000)[1:8]
    #     ss_models = list(ss_models[ii])
    #
    #     #ss_models = list(fresV_sorted[-8:])
    #
    #     saved_models = ss_models + saved_models
    saved_models = gfile(model, '_ep27_.*pt$')
Code example #5
import glob
from utils_file import gdir

# successive scratch-style overrides: only the last `results_dirs` assignment
# (plus the += extensions below) is effective
results_dirs = glob.glob('/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/eval_cnn/RES_1mm/eval_metric_on_pv/data_t*')
results_dirs = glob.glob('/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/eval_cnn/RES_1.4mm/eval_metric_on_pv/data_t*')
results_dirs = glob.glob('/home/romain.valabregue/datal/PVsynth/eval_cnn/RES_1mm/eval_metric_on_pv/data_t*')
results_dirs = glob.glob('/home/romain.valabregue/datal/PVsynth/eval_cnn/RES_1mm/eval_metric_on_bin/data_t*')

results_dirs = glob.glob('/home/romain.valabregue/datal/PVsynth/eval_cnn/RES_1.4mm/eval_metric_on_pv/data_t*')
results_dirs = glob.glob('/home/romain.valabregue/datal/PVsynth/eval_cnn/RES_1.4mm/eval_metric_on_bin/data_t*')

results_dirs += glob.glob('/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/eval_samseg/eval-1mm*')
results_dirs += glob.glob('/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/eval_samseg/eval-14mm*')


# Explore training curves (only the last `res` assignment reaches report_learning_curves)
res='/home/fabien.girka/data/segmentation_tasks/RES_1.4mm/bin_synth_data_64_common_noise_no_gamma/results_cluster'
res = gdir('/home/fabien.girka/data/segmentation_tasks/RES_1mm/','data_64_common_noise_no_gamma')
res = gdir('/home/romain.valabregue/datal/PVsynth/training','pve')
res = gdir('/home/romain.valabregue/datal/PVsynth/jzay/training/RES28mm','data')
res = gdir('/home/romain.valabregue/datal/PVsynth/jzay/training/RES1mm','data')
res = gdir('/home/romain.valabregue/datal/PVsynth/training/RES_14mm_tissue','data')
res = gdir('/home/romain.valabregue/datal/PVsynth/jzay/training/RES1mm_prob','pve_synth_mod3_P128$')
res = gdir('/home/romain.valabregue/datal/PVsynth/jzay/training/RES1mm_prob','aniso')
res = gdir(res,'results_cluster')
res = ['/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/jzay/training/RES1mm_prob/pve_synth_mod3_P128_aniso_LogLkd_reg_multi/result',
       '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/jzay/training/RES1mm_prob/pve_synth_mod3_P128_aniso_LogLkd_reg_unis_lam1/results_cluster',
       '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/jzay/training/RES1mm_prob/pve_synth_mod3_P128_aniso_LogLkd_classif/results_cluster',
       ]
res = ['/home/romain.valabregue/datal/PVsynth/jzay/training/RES1mm_prob/pve_synth_mod3_P128/results_cluster/']
report_learning_curves(res)
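
Since each assignment above overwrites res, only the last directory list is plotted; a two-line check (not part of the original) of what report_learning_curves received:

# Print the result directories actually selected above.
for r in res:
    print(r)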

Code example #6
File: test_predic_cnn.py Project: GFabien/torchQC
from torchvision.transforms import Compose

from torchio.data.io import write_image, read_image
from torchio.transforms import RandomMotionFromTimeCourse, RandomAffine, \
    CenterCropOrPad, RandomElasticDeformation, CropOrPad, RandomNoise
from torchio import Image, ImagesDataset, transforms, INTENSITY, LABEL, Interpolation, Subject
from utils_file import get_parent_path, gfile, gdir
from doit_train import do_training, get_motion_transform
from slices_2 import do_figures_from_file
from utils import reduce_name_list, get_ep_iter_from_res_name

# Explore CSV results
dqc = [
    '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/NN_regres_motion'
]
dres = gdir(dqc, 'RegMotNew.*train_hcp400_ms.*0001')
dres = gdir(dqc, 'RegMotNew.*hcp400_ms.*B4.*L1.*0001')
dres = gdir(dqc, 'R.*')
resname = get_parent_path(dres)[1]
#sresname = [rr[rr.find('hcp400_')+7: rr.find('hcp400_')+17] for rr in resname ]; sresname[2] += 'le-4'
sresname = resname
commonstr, sresname = reduce_name_list(sresname)
print('common str {}'.format(commonstr))

target = 'ssim'
target_scale = 1
#target='random_noise'; target_scale=10

legend_str = []
col = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
for ii, oneres in enumerate(dres):
    pass  # loop body truncated in the original listing
Code example #7
File: test_val_data.py Project: SayJMC/torchQC
# CAT12
import pandas as pd
from utils_file import get_parent_path, gfile, gdir

# (labelsujid, ytrue, print_accuracy_df and print_accuracy are defined earlier
#  in the original file)
rescat = pd.read_csv(
    '/home/romain.valabregue/datal/QCcnn/CATI_datasets/res_cat12_suj18999.csv')
rescat.index = [sss.replace(';', '+')
                for sss in rescat.sujid]  # .values.replace(";","+")
rescat = rescat.loc[labelsujid]
print_accuracy_df(rescat, ytrue)
print_accuracy([rescat], ['IQR'],
               ytrue,
               prediction_name='IQR',
               inverse_prediction=False)

# READ RESULT cnn
prefix = "/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/"
resdir = "/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/predict_torch/"
rd = gdir(resdir, '^cati')
resfile = gfile(rd, 'all.csv')
resname = get_parent_path(resfile, 2)[1]
res = [pd.read_csv(f) for f in resfile]

for ii in range(len(resname)):
    print(ii)
    ssujid = []
    for ff in res[ii].fin:
        dd = ff.split('/')
        if dd[-1] == '': dd.pop()  # paths ending in '/' leave a trailing empty part
        nn = len(dd)
        ssujid.append(dd[nn - 5] + '+' + dd[nn - 4] + '+' + dd[nn - 3])
    res[ii].index = ssujid
    if ii == 0: sujid = ssujid
    aa = set(ssujid).difference(set(labelsujid))
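
The subject id is rebuilt from fixed positions in the image path; a tiny illustration on a hypothetical path (the real directory layout is an assumption):

ff = '/data/cati/center01/subj0042/session01/cat12/catreport.pdf'  # hypothetical
dd = ff.split('/')
nn = len(dd)
print(dd[nn - 5] + '+' + dd[nn - 4] + '+' + dd[nn - 3])  # -> center01+subj0042+session01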
Code example #8
File: explore_seg_result.py Project: GReguig/torchQC
import glob
from utils_file import gdir, gfile, get_parent_path

results_dirs = glob.glob(
    '/home/romain.valabregue/datal/PVsynth/eval_cnn/RES_1.4mm/eval_metric_on_pv/data_t*'
)
results_dirs = glob.glob(
    '/home/romain.valabregue/datal/PVsynth/eval_cnn/RES_1.4mm/eval_metric_on_bin/data_t*'
)

results_dirs += glob.glob(
    '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/eval_samseg/eval-1mm*'
)
results_dirs += glob.glob(
    '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/eval_samseg/eval-14mm*'
)

# Explore training curves
res = '/home/fabien.girka/data/segmentation_tasks/RES_1.4mm/bin_synth_data_64_common_noise_no_gamma/results_cluster'
res = gdir('/home/fabien.girka/data/segmentation_tasks/RES_1mm/',
           'data_64_common_noise_no_gamma')
res = gdir('/home/romain.valabregue/datal/PVsynth/training', 'pve')
res = gdir('/home/romain.valabregue/datal/PVsynth/jzay/training/RES28mm',
           'data')
res = gdir('/home/romain.valabregue/datal/PVsynth/jzay/training/RES1mm',
           'data')
res = gdir('/home/romain.valabregue/datal/PVsynth/training/RES_14mm_tissue',
           'data')
res = gdir(res, 'resul')
report_learning_curves(res)

# Explore synthetic data histograms
results_dirs = glob.glob(
    '/home/romain.valabregue/datal/PVsynth/RES_1.4mm/t*/513130')
f = gfile(results_dirs, '^5.*nii')
resname = get_parent_path(results_dirs, 2)[1]
Code example #9
import torch.optim as optim
from torchvision.transforms import Compose

from torchio.data.io import write_image, read_image
from torchio.transforms import RandomMotionFromTimeCourse, RandomAffine, \
    CenterCropOrPad, RandomElasticDeformation, CropOrPad, RandomNoise, ApplyMask
from torchio import Image, ImagesDataset, transforms, INTENSITY, LABEL, Interpolation, Subject
from utils_file import get_parent_path, gfile, gdir
from doit_train import do_training, get_motion_transform
from slices_2 import do_figures_from_file
from utils import reduce_name_list, get_ep_iter_from_res_name, remove_extension


# Explore CSV results
dqc = ['/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/NN_regres_motion']
dres = gdir(dqc, 'RegMotNew.*train_hcp400_ms.*0001')
dres = gdir(dqc, 'RegMotNew.*hcp400_ms.*B4.*L1.*0001')
dres = gdir(dqc, 'R.*')

dres_reg_exp, figname = ['.*hcp.*ms', '.*hcp.*T1', 'cati.*ms', 'cati.*T1'], ['hcp_ms', 'hcp_T1', 'cati_ms', 'cati_T1']
dres_reg_exp, figname = [ 'cati.*ms', 'cati.*T1'], ['cati_ms', 'cati_T1']
for rr, fign in zip(dres_reg_exp, figname):
    dres = gdir(dqc, rr)
    resname = get_parent_path(dres)[1]
    print(len(resname)); print(resname)

    #sresname = [rr[rr.find('hcp400_')+7: rr.find('hcp400_')+17] for rr in resname ]; sresname[2] += 'le-4'
    sresname = resname
    commonstr, sresname = reduce_name_list(sresname)
    print('common str {}'.format(commonstr))
Code example #10
# res_valOn: gather per-model res_valOn_* result CSVs
import seaborn as sns

# (dqc and remove_string_from_name_list come from the surrounding file,
#  as in the previous examples)
dd = gfile('/network/lustre/dtlake01/opendata/data/ds000030/rrr/CNN_cache_new',
           '_')
dir_fig = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/NN_regres_motion/figure/motion_regress/eval2/'
dir_fig = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/QCcnn/NN_regres_random_noise/figure2/'
data_name_list = get_parent_path(dd)[1]

dres_reg_exp, figname = ['Reg.*D0_DC0'], ['noise']
dres_reg_exp, figname = ['.*hcp.*ms', '.*hcp.*T1', 'cati.*ms', 'cati.*T1'
                         ], ['hcp_ms', 'hcp_T1', 'cati_ms', 'cati_T1']
sns.set(style="whitegrid")

csv_regex = 'res_valOn_'
for rr, fign in zip(dres_reg_exp, figname):
    dres = gdir(dqc, rr)
    resname = get_parent_path(dres)[1]
    resname = remove_string_from_name_list(resname, [
        'RegMotNew_', 'Size182_ConvN_C16_256_Lin40_50_', '_B4',
        '_Loss_L1_lr0.0001', '_nw0_D0'
    ])
    resname = [fign + '_' + zz for zz in resname]
    print('For {} found {} dir'.format(fign, len(resname)))

    if 0 == 1:  # manually disabled block
        for oneres, resn in zip(dres, resname):
            fres_valOn = gfile(oneres, csv_regex)
            print('Found {} <{}> for {} '.format(len(fres_valOn), csv_regex,
                                                 resn))
            if len(fres_valOn) == 0:
                continue
Code example #11
import json

# (conf_all, res and gs are built earlier in the original file)
for k in conf_all.keys():
    filename = res + k + '.json'
    with open(filename, 'w') as file:
        json.dump(conf_all[k], file, indent=4, sort_keys=True)

filename = res + 'grid_search.json'
with open(filename, 'w') as file:
    json.dump(gs, file, indent=4, sort_keys=True)


# generating jobs for validation
from utils_file import gdir, gfile, get_parent_path
f = gfile('/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/training/RES_1mm_tissue/pve_synth_data_92_common_noise_no_gamma/results_cluster',
          'model.*tar')
d = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/PVsynth/training/RES_1mm_tissue/'
dres = gdir(d, ['.*', 'result'])
dresname = get_parent_path(dres, level=2)[1]
dresname = [dd.split('_')[0] + '_' + dd.split('_')[1] for dd in dresname]

for one_res, resn in zip(dres, dresname):
    f = gfile(one_res,'model.*tar')
    fname = get_parent_path(f)[1]
    for ff in f:
        print('\"{}\",'.format(ff))

for one_res, resn in zip(dres, dresname):
    f = gfile(one_res,'model.*tar')
    fname = get_parent_path(f)[1]

    for ff in fname:
        fname_ep = ff.split('_')[1]
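
The second loop is truncated in the listing right after the epoch token is parsed; a hypothetical illustration of that split, assuming checkpoint names like model_ep<epoch>_... (the pattern is only implied by the model.*tar regex above):

ff = 'model_ep12_it4500.tar'     # hypothetical checkpoint name
fname_ep = ff.split('_')[1]      # -> 'ep12'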
Code example #12
File: create_job_eval.py Project: SayJMC/torchQC
weights1 = prefix + "NN_saved_pytorch/torcho2_full_098_equal_BN05_b4_BCEWithLogitsLoss_SDG/quadriview_ep10.pt"
weights1 = prefix + "NN_saved_pytorch/modelV2_msbrain_098_equal_BN05_b1_BCEWithLogitsLoss_SDG/quadriview_ep10.pt"
weights1 = prefix + "NN_saved_pytorch/modelV2_msbrain_098_equal_BN05_b4_BCEWithLogitsLoss_SDG/quadriview_ep10.pt"
weights1 = prefix + "NN_saved_pytorch/modelV2_one256_msbrain_098_equal_BN0_b4_BCEWithLogitsLoss_SDG/quadriview_ep10.pt"
weights1 = prefix + "modelV2_last128_msbrain_098_equal_BN0_b4_BCEWithLogitsLoss_SDG/quadriview_ep10.pt"

name = "cati_modelV2_last128_msbrain_098_equal_BN05_b4_BCEWithLogitsLoss_SDG"
resdir = prefix + "predict_torch/" + name + '/'

py_options = ' --BN_momentum 0.5  --do_reslice --apply_mask --model_type=2'  # --use_gpu 0 '

# for CATI
tab = pd.read_csv(prefix + "CATI_datasets/all_cati.csv", index_col=0)
clip_val = tab.meanW.values + 3 * tab.stdW.values  # upper intensity bound: mean + 3*std

dcat = gdir(tab.cenir_QC_path, 'cat12')
# dspm = gdir(tab.cenir_QC_path,'spm' )

fms = gfile(dcat, '^ms.*ni', opts={"items": 1})
fmask = gfile(dcat, '^mask_brain.*gz', opts={"items": 1})
# fms = gfile(dspm,'^ms.*ni',opts={"items":1})
faff = gfile(dcat, '^aff.*txt', opts={"items": 1})
fref = '/network/lustre/iss01/cenir/analyse/irm/users/romain.valabregue/HCPdata/suj_100307/T1w_1mm.nii.gz'

# for ABIDE
dcat = gdir('/network/lustre/dtlake01/opendata/data/ABIDE/cat12', ['^su', 'anat'])
# for ds30
dcat = gdir('/network/lustre/dtlake01/opendata/data/ds000030/cat12', ['^su', 'anat'])

# for validation
tab = pd.read_csv(prefix + "Motion_brain_ms_val_hcp200.csv", index_col=0)