コード例 #1
0
def save_results(respfile, Yhat, S2, maskvol, Z=None, outputsuffix=None,
                 results=None, save_path=''):
    """Write the normative model outputs to ``save_path``.

    Every output is written to ``<name><outputsuffix><ext>`` inside
    ``save_path``. ``ext`` is the extension of ``respfile``, or ``'.pkl'``
    when ``respfile`` is None. For nifti/cifti responses, ``respfile`` is
    passed to ``fileio.save`` as the example image.

    :param respfile: response file the model was estimated on (or None)
    :param Yhat: predictive mean, saved as ``yhat``
    :param S2: predictive variance, saved as ``ys2``
    :param maskvol: mask forwarded to ``fileio.save``
    :param Z: optional deviation scores, saved as ``Z``
    :param outputsuffix: optional text placed before the file extension
    :param results: optional dict of metric name -> values; each entry is
                    saved under its own key
    :param save_path: directory the files are written to
    """
    print("Writing outputs ...")

    example_img = None
    extension = '.pkl'
    if respfile is not None:
        # Only neuroimaging formats need an example image to restore layout.
        if fileio.file_type(respfile) in ('cifti', 'nifti'):
            example_img = respfile
        extension = fileio.file_extension(respfile)

    suffix = '' if outputsuffix is None else str(outputsuffix)
    ext = suffix + extension

    to_write = [('yhat', Yhat), ('ys2', S2)]
    if Z is not None:
        to_write.append(('Z', Z))
    if results is not None:
        to_write.extend((metric, results[metric]) for metric in list(results))

    for stem, data in to_write:
        fileio.save(data, os.path.join(save_path, stem + ext),
                    example=example_img, mask=maskvol)
コード例 #2
0
# Outputs combined across batches by collect_nm, in the order they are
# written: (file prefix, pandas concat axis). Per-subject matrices (yhat,
# ys2, Z) are joined column-wise; per-variable metrics are joined row-wise.
_COLLECTED_OUTPUTS = (('pRho', 0), ('Rho', 0), ('Z', 1), ('yhat', 1),
                      ('ys2', 1), ('RMSE', 0), ('SMSE', 0), ('EXPV', 0),
                      ('MSLL', 0))


def _write_placeholder_outputs(batch, numsubjects, batch_size, outputsuffix,
                               file_ext):
    """Write dummy outputs into ``batch`` so that a failed batch can be collected.

    pRho is filled with ones and the other per-variable metrics with zeros
    (length ``batch_size``). yhat, ys2 and Z are ``numsubjects x batch_size``
    zero matrices.
    """
    for prefix in ('pRho', 'Rho', 'RMSE', 'SMSE', 'EXPV', 'MSLL'):
        if prefix == 'pRho':
            values = np.ones(batch_size)
        else:
            values = np.zeros(batch_size)
        fileio.save(pd.Series(values), batch + prefix + outputsuffix + file_ext)
    for prefix in ('yhat', 'ys2', 'Z'):
        fileio.save(pd.DataFrame(np.zeros([numsubjects, batch_size])),
                    batch + prefix + outputsuffix + file_ext)
    # Create the model directory inside the batch directory (the same path the
    # isdir check looks at), not in the current working directory.
    if not os.path.isdir(batch + 'Models'):
        os.mkdir(batch + 'Models')


def _concat_batch_outputs(processing_dir, prefix, outputsuffix, file_ext,
                          axis):
    """Concatenate one output type over all batches and save the result.

    Does nothing when no batch contains a matching file.
    """
    filenames = glob.glob(processing_dir + 'batch_*/' + prefix +
                          outputsuffix + '*')
    if not filenames:
        return
    filenames = fileio.sort_nicely(filenames)
    combined = pd.concat([pd.DataFrame(fileio.load(name)) for name in filenames],
                         ignore_index=True, axis=axis)
    fileio.save(combined, processing_dir + prefix + outputsuffix + file_ext)


def _collect_meta_data(processing_dir):
    """Merge the per-batch model meta-data into ``processing_dir/Models``.

    The first batch's meta-data (after natural sorting) serves as the
    template. If its 'standardize' flag is set, the response and covariate
    means/stds of every batch are stacked.
    """
    meta_filenames = glob.glob(processing_dir + 'batch_*/Models/' +
                               'meta_data.md')
    if not meta_filenames:
        return
    meta_filenames = fileio.sort_nicely(meta_filenames)
    with open(meta_filenames[0], 'rb') as handle:
        meta_data = pickle.load(handle)
    if meta_data['standardize']:
        mY, sY, mX, sX = [], [], [], []
        for meta_filename in meta_filenames:
            # Load each batch's own statistics. Previously the first batch's
            # values were appended once per batch.
            with open(meta_filename, 'rb') as handle:
                batch_meta = pickle.load(handle)
            mY.append(batch_meta['mean_resp'])
            sY.append(batch_meta['std_resp'])
            mX.append(batch_meta['mean_cov'])
            sX.append(batch_meta['std_cov'])
        meta_data['mean_resp'] = np.stack(mY)
        meta_data['std_resp'] = np.stack(sY)
        meta_data['mean_cov'] = np.stack(mX)
        meta_data['std_cov'] = np.stack(sX)
    with open(os.path.join(processing_dir, 'Models', 'meta_data.md'),
              'wb') as handle:
        pickle.dump(meta_data, handle)


def collect_nm(processing_dir,
               job_name,
               func='estimate',
               collect=False,
               binary=False,
               batch_size=None,
               outputsuffix='_estimate'):
    """Check all batches and optionally collect them.

    ** Input:
        * processing_dir  -> Full path to the processing directory. Paths are
                             built by string concatenation, so it should end
                             with a path separator.
        * job_name        -> Job name, used to find the job scripts
                             (``<job_name>*.sh``) of failed batches
        * func            -> Function run in the batches (e.g. 'estimate',
                             'fit', 'predict', 'transfer')
        * collect         -> If True, data is checked for failed batches and
                             collected; if False, data is only checked
        * binary          -> If True, outputs are '.pkl' files; otherwise '.txt'
        * batch_size      -> Number of response variables per batch. It is
                             overridden by the yhat shape unless func == 'fit'.
        * outputsuffix    -> Suffix of the output files to check/collect

    ** Output:
        * Files in processing_dir that combine the results across all
          batches, plus a 'failed_batches' file listing the failed batches'
          job scripts.
        * Returns 1 if no batch failed, 0 otherwise.

    written by (primarily) T Wolfers, (adapted) SM Kia
    """
    file_ext = '.pkl' if binary else '.txt'

    # detect the batch directories (glob returns them with a trailing '/')
    batches = glob.glob(processing_dir + 'batch_*/')

    # one entry (list of job scripts) per failed batch
    batch_fail = []

    if func != 'fit':
        # The first available yhat file gives the number of subjects (rows)
        # and response variables per batch (columns).
        file_example = []
        for batch in batches:
            file_example = glob.glob(batch + 'yhat' + outputsuffix + file_ext)
            if file_example:
                break
        # Use the same truthiness test as file_ext, so that '.pkl' files are
        # always read with pandas.
        if binary:
            file_example = pd.read_pickle(file_example[0])
        else:
            file_example = fileio.load(file_example[0])
        numsubjects = file_example.shape[0]
        batch_size = file_example.shape[1]

        # artificially creates files for batches that were not executed
        batch_dirs = fileio.sort_nicely(glob.glob(processing_dir + 'batch_*/'))
        for batch in batch_dirs:
            filepath = glob.glob(batch + 'yhat' + outputsuffix + '*')
            if not filepath:
                batch1 = glob.glob(batch + '/' + job_name + '*.sh')
                print(batch1)
                batch_fail.append(batch1)
                if collect is True:
                    _write_placeholder_outputs(batch, numsubjects, batch_size,
                                               outputsuffix, file_ext)
            else:
                # if more than 10% of yhat is nan then consider the batch as
                # a failed batch
                yhat = fileio.load(filepath[0])
                if np.count_nonzero(~np.isnan(yhat)) / np.prod(yhat.shape) < 0.9:
                    batch1 = glob.glob(batch + '/' + job_name + '*.sh')
                    print('More than 10% nans in ' +
                          (batch1[0] if batch1 else batch))
                    batch_fail.append(batch1)

    # combines all output files across batches
    if collect is True:
        for prefix, axis in _COLLECTED_OUTPUTS:
            _concat_batch_outputs(processing_dir, prefix, outputsuffix,
                                  file_ext, axis)

        if func not in ('predict', 'transfer'):
            if (batches and not os.path.isdir(processing_dir + 'Models') and
                    os.path.exists(os.path.join(batches[0], 'Models'))):
                os.mkdir(processing_dir + 'Models')

            _collect_meta_data(processing_dir)

            # copy every batch's models, renumbering them by global index
            batch_dirs = glob.glob(processing_dir + 'batch_*/')
            if batch_dirs:
                batch_dirs = fileio.sort_nicely(batch_dirs)
            for b, batch_dir in enumerate(batch_dirs):
                src_files = glob.glob(batch_dir + 'Models/*.pkl')
                if src_files:
                    src_files = fileio.sort_nicely(src_files)
                    for f, full_file_name in enumerate(src_files):
                        if os.path.isfile(full_file_name):
                            n = os.path.basename(full_file_name).split('_')
                            n[-1] = str(b * batch_size + f) + '.pkl'
                            shutil.copy(full_file_name,
                                        processing_dir + 'Models/' + '_'.join(n))
                elif func == 'fit':
                    batch1 = glob.glob(batch_dir + '/' + job_name + '*.sh')
                    print('Failed batch: ' +
                          (batch1[0] if batch1 else batch_dir))
                    batch_fail.append(batch1)

    # list batches that were not executed
    print('Number of batches that failed:' + str(len(batch_fail)))
    batch_fail_df = pd.DataFrame(batch_fail)
    if file_ext == '.txt':
        fileio.save_pd(batch_fail_df,
                       processing_dir + 'failed_batches' + file_ext)
    else:
        fileio.save(batch_fail_df,
                    processing_dir + 'failed_batches' + file_ext)

    return 0 if batch_fail else 1
コード例 #3
0
def estimate(covfile, respfile, **kwargs):
    """ Estimate a normative model

    This will estimate a model in one of three settings, according to the
    particular parameters specified (see below):

    * under k-fold cross-validation.
      requires respfile, covfile and cvfolds>=2
    * estimating a training dataset then applying to a second test dataset.
      requires respfile, covfile, testcov and testresp.
    * estimating on a training dataset and outputting forward maps
      (predictive mean and variance) for the test covariates.
      requires respfile, covfile and testcov

    The models are estimated on the basis of data stored on disk in ascii or
    neuroimaging data formats (nifti or cifti). Ascii data should be in
    tab or space delimited format with the number of subjects in rows and the
    number of variables in columns. Neuroimaging data will be reshaped
    into the appropriate format

    Basic usage::

        estimate(covfile, respfile, [extra_arguments])

    where the variables are defined below. Note that either the cvfolds
    parameter or (testcov, testresp) should be specified, but not both. If
    both cvfolds and testcov are given, cross-validation is run on the
    training data and testcov/testresp are ignored.

    :param respfile: response variables for the normative model
    :param covfile: covariates used to predict the response variable
    :param maskfile: mask used to apply to the data (nifti only)
    :param cvfolds: Number of cross-validation folds
    :param testcov: Test covariates
    :param testresp: Test responses
    :param alg: Algorithm for normative model (default 'gpr')
    :param configparam: Parameters controlling the estimation algorithm
    :param saveoutput: Save the output to disk? Otherwise returned as arrays.
                       Accepts a bool or the strings 'True'/'False'
                       (default True).
    :param savemodel: Save every fitted model and the meta-data to the
                      'Models' directory in the current working directory.
                      Accepts a bool or 'True'/'False' (default False).
    :param outputsuffix: Text string to add to the output filenames
                         (default '_estimate')
    :param inscaler: Scaling approach for input covariates, could be 'None' (Default),
                    'standardize', 'minmax', or 'robminmax'.
    :param outscaler: Scaling approach for output responses, could be 'None' (Default),
                    'standardize', 'minmax', or 'robminmax'.
    :param warp: (optional) if given, Z and the evaluation use warped
                 responses. The warp is accessed via ``nm.blr`` here, so it
                 presumably applies to the 'blr' algorithm only (confirm).
                 It is also forwarded to the model through kwargs.
    :param trbefile: (alg='hbr', cross-validation only) batch-effects file
                     for the training data; all zeros if omitted

    All outputs are written to disk in the same format as the input. These are:

    :outputs: * yhat - predictive mean
              * ys2 - predictive variance
              * nm - normative model
              * Z - deviance scores
              * Rho - Pearson correlation between true and predicted responses
              * pRho - parametric p-value for this correlation
              * rmse - root mean squared error between true/predicted responses
              * smse - standardised mean squared error

    :returns: None when saveoutput is True. Otherwise, under cross-validation
              or when testresp is given: (Yhat, S2, nm, Z, results); else
              (Yhat, S2, nm). Only the test rows are returned, and ``nm`` is
              the last model fitted in the loop.

    The outputsuffix may be useful to estimate multiple normative models in the
    same directory (e.g. for custom cross-validation schemes)
    """
    
    # parse keyword arguments 
    maskfile = kwargs.pop('maskfile',None)
    cvfolds = kwargs.pop('cvfolds', None)
    testcov = kwargs.pop('testcov', None)
    testresp = kwargs.pop('testresp',None)
    alg = kwargs.pop('alg','gpr')
    outputsuffix = kwargs.pop('outputsuffix','_estimate')
    inscaler = kwargs.pop('inscaler','None')
    outscaler = kwargs.pop('outscaler','None')
    # warp is read with get() rather than pop(), so it stays in kwargs and is
    # also forwarded to norm_init / nm.estimate / nm.predict below
    warp = kwargs.get('warp', None)

    # convert from strings if necessary
    saveoutput = kwargs.pop('saveoutput','True')
    if type(saveoutput) is str:
        saveoutput = saveoutput=='True'
    savemodel = kwargs.pop('savemodel','False')
    if type(savemodel) is str:
        savemodel = savemodel=='True'
    
    # models are written relative to the current working directory
    if savemodel and not os.path.isdir('Models'):
        os.mkdir('Models')

    # load data
    print("Processing data in " + respfile)
    X = fileio.load(covfile)
    Y, maskvol = load_response_vars(respfile, maskfile)
    # promote 1-D inputs to single-column 2-D arrays
    if len(Y.shape) == 1:
        Y = Y[:, np.newaxis]
    if len(X.shape) == 1:
        X = X[:, np.newaxis]
    Nmod = Y.shape[1]
    
    if (testcov is not None) and (cvfolds is None): # a separate test dataset
        
        run_cv = False
        cvfolds = 1
        Xte = fileio.load(testcov)
        if len(Xte.shape) == 1:
            Xte = Xte[:, np.newaxis]
        if testresp is not None:
            Yte, testmask = load_response_vars(testresp, maskfile)
            if len(Yte.shape) == 1:
                Yte = Yte[:, np.newaxis]
        else:
            # no test responses: use zero placeholders so the arrays can be
            # stacked (Z and the metrics are not computed in this case)
            sub_te = Xte.shape[0]
            Yte = np.zeros([sub_te, Nmod])
            
        # treat as a single train-test split
        # (the test rows are appended after the training rows)
        testids = range(X.shape[0], X.shape[0]+Xte.shape[0])
        splits = CustomCV((range(0, X.shape[0]),), (testids,))

        Y = np.concatenate((Y, Yte), axis=0)
        X = np.concatenate((X, Xte), axis=0)
        
    else:
        run_cv = True
        # we are running under cross-validation
        splits = KFold(n_splits=cvfolds, shuffle=True)
        testids = range(0, X.shape[0])
        # HBR needs batch effects; they are split per fold further below
        if alg=='hbr':
           trbefile = kwargs.get('trbefile', None) 
           if trbefile is not None:
                be = fileio.load(trbefile)
                if len(be.shape) == 1:
                    be = be[:, np.newaxis]
           else:
                print('No batch-effects file! Initilizing all as zeros!')
                be = np.zeros([X.shape[0],1])

    # find and remove bad variables from the response variables
    # note: the covariates are assumed to have already been checked
    # (keep columns with at least one finite value and nonzero variance)
    nz = np.where(np.bitwise_and(np.isfinite(Y).any(axis=0),
                                 np.var(Y, axis=0) != 0))[0]

    # run cross-validation loop
    Yhat = np.zeros_like(Y)
    S2 = np.zeros_like(Y)
    Z = np.zeros_like(Y)
    # negative log likelihood per model and fold (not used further here)
    nlZ = np.zeros((Nmod, cvfolds))
    
    scaler_resp = []
    scaler_cov = []
    mean_resp = [] # this is just for computing MSLL
    std_resp = []   # this is just for computing MSLL
    
    if warp is not None:
        # warped responses and per-fold warped training mean/std (for MSLL)
        Ywarp = np.zeros_like(Yhat)
        mean_resp_warp = [np.zeros(Y.shape[1]) for s in range(splits.n_splits)]
        std_resp_warp = [np.zeros(Y.shape[1]) for s in range(splits.n_splits)]

    for idx in enumerate(splits.split(X)):

        # idx is (fold number, (train indices, test indices))
        fold = idx[0]
        tr = idx[1][0]
        ts = idx[1][1]

        # standardize responses and covariates, ignoring invalid entries
        iy_tr, jy_tr = np.ix_(tr, nz)
        iy_ts, jy_ts = np.ix_(ts, nz)
        mY = np.mean(Y[iy_tr, jy_tr], axis=0)
        sY = np.std(Y[iy_tr, jy_tr], axis=0)
        mean_resp.append(mY)
        std_resp.append(sY)
        
        # scalers are fitted on the training rows only, then applied to test
        if inscaler in ['standardize', 'minmax', 'robminmax']:
            X_scaler = scaler(inscaler)
            Xz_tr = X_scaler.fit_transform(X[tr, :])
            Xz_ts = X_scaler.transform(X[ts, :])
            scaler_cov.append(X_scaler)
        else:
            Xz_tr = X[tr, :]
            Xz_ts = X[ts, :]
            
        if outscaler in ['standardize', 'minmax', 'robminmax']:
            Y_scaler = scaler(outscaler)
            Yz_tr = Y_scaler.fit_transform(Y[iy_tr, jy_tr])
            scaler_resp.append(Y_scaler)
        else:
            Yz_tr = Y[iy_tr, jy_tr]
        
        # HBR reads batch effects from files passed through kwargs, so the
        # per-fold split is written to temporary files in the current working
        # directory (they are not removed in this function)
        if (run_cv==True and alg=='hbr'):
            fileio.save(be[tr,:], 'be_kfold_tr_tempfile.pkl')
            fileio.save(be[ts,:], 'be_kfold_ts_tempfile.pkl')
            kwargs['trbefile'] = 'be_kfold_tr_tempfile.pkl'
            kwargs['tsbefile'] = 'be_kfold_ts_tempfile.pkl'
            
        # estimate the models for all subjects
        # (one model per valid response variable, i.e. per column in nz)
        for i in range(0, len(nz)):  
            print("Estimating model ", i+1, "of", len(nz))
            nm = norm_init(Xz_tr, Yz_tr[:, i], alg=alg, **kwargs)
                
            try:
                nm = nm.estimate(Xz_tr, Yz_tr[:, i], **kwargs)     
                yhat, s2 = nm.predict(Xz_ts, Xz_tr, Yz_tr[:, i], **kwargs)
                
                if savemodel:
                    nm.save('Models/NM_' + str(fold) + '_' + str(nz[i]) + 
                            outputsuffix + '.pkl' )
                
                # map predictions and variances back to the original units
                if outscaler == 'standardize': 
                    Yhat[ts, nz[i]] = Y_scaler.inverse_transform(yhat, index=i)
                    # NOTE(review): scales by sY (np.std of the raw training
                    # responses), not by the scaler's own statistics;
                    # presumably equivalent for 'standardize' -- confirm
                    S2[ts, nz[i]] = s2 * sY[i]**2
                elif outscaler in ['minmax', 'robminmax']:
                    Yhat[ts, nz[i]] = Y_scaler.inverse_transform(yhat, index=i)
                    S2[ts, nz[i]] = s2 * (Y_scaler.max[i] - Y_scaler.min[i])**2
                else:
                    Yhat[ts, nz[i]] = yhat
                    S2[ts, nz[i]] = s2
                    
                nlZ[nz[i], fold] = nm.neg_log_lik
                
                # deviation scores need true test responses
                if (run_cv or testresp is not None):
                    # warp the labels?
                    # TODO: Warping for scaled data
                    if warp is not None:
                        # assumes the warp parameters follow the first entry of
                        # nm.blr.hyp -- TODO confirm against the BLR code
                        warp_param = nm.blr.hyp[1:nm.blr.warp.get_n_params()+1] 
                        Ywarp[ts, nz[i]] = nm.blr.warp.f(Y[ts, nz[i]], warp_param)
                        Ytest = Ywarp[ts, nz[i]]
                        
                        # Save warped mean of the training data (for MSLL)
                        yw = nm.blr.warp.f(Y[tr, nz[i]], warp_param)
                        mean_resp_warp[fold][i] = np.mean(yw)
                        std_resp_warp[fold][i] = np.std(yw)
                    else:
                        Ytest = Y[ts, nz[i]] 
                    
                    Z[ts, nz[i]] = (Ytest - Yhat[ts, nz[i]]) / \
                                    np.sqrt(S2[ts, nz[i]])       
                    
            # any failure of a single model is reported and its outputs are
            # set to NaN, so that the remaining models still run
            except Exception as e:
                exc_type, exc_obj, exc_tb = sys.exc_info()
                fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                print("Model ", i+1, "of", len(nz),
                      "FAILED!..skipping and writing NaN to outputs")
                print("Exception:")
                print(e)
                print(exc_type, fname, exc_tb.tb_lineno)

                Yhat[ts, nz[i]] = float('nan')
                S2[ts, nz[i]] = float('nan')
                nlZ[nz[i], fold] = float('nan')
                # NOTE(review): when both testcov and cvfolds are given
                # without testresp, run_cv is True but Z is not NaN-filled
                # here (it keeps its initial zeros)
                if testcov is None:
                    Z[ts, nz[i]] = float('nan')
                else:
                    if testresp is not None:
                        Z[ts, nz[i]] = float('nan')


    if savemodel:
        print('Saving model meta-data...')
        with open('Models/meta_data.md', 'wb') as file:
            pickle.dump({'valid_voxels':nz, 'fold_num':cvfolds, 
                         'mean_resp':mean_resp, 'std_resp':std_resp, 
                         'scaler_cov':scaler_cov, 'scaler_resp':scaler_resp, 
                         'regressor':alg, 'inscaler':inscaler, 
                         'outscaler':outscaler}, file, protocol=PICKLE_PROTOCOL)    

    # compute performance metrics
    # NOTE(review): only the first fold's training mean/std (index [0]) are
    # passed to evaluate, even under k-fold cross-validation
    if (run_cv or testresp is not None):
        print("Evaluating the model ...")
        if warp is None:
            results = evaluate(Y[testids, :], Yhat[testids, :], 
                               S2=S2[testids, :], mY=mean_resp[0], 
                               sY=std_resp[0])
        else:
            results = evaluate(Ywarp[testids, :], Yhat[testids, :], 
                               S2=S2[testids, :], mY=mean_resp_warp[0], 
                               sY=std_resp_warp[0])
        
        
    # Set writing options
    # (only the test rows, testids, are saved or returned)
    if saveoutput:
        if (run_cv or testresp is not None):
            save_results(respfile, Yhat[testids, :], S2[testids, :], maskvol, 
                         Z=Z[testids, :], results=results, 
                         outputsuffix=outputsuffix)
            
        else:
            save_results(respfile, Yhat[testids, :], S2[testids, :], maskvol,
                         outputsuffix=outputsuffix)
                
    else:
        # nm is the last model fitted in the loop above
        if (run_cv or testresp is not None):
            output = (Yhat[testids, :], S2[testids, :], nm, Z[testids, :], 
                      results)
        else:
            output = (Yhat[testids, :], S2[testids, :], nm)
        
        return output
コード例 #4
0
def estimate(args):
    """Fit a Neural-Process normative model on NIfTI responses and save results.

    Fixed effects are estimated ``args.m`` times, each on a bootstrap resample
    of the training rows:
    - ``args.estimator == 'ST'`` fits one ``LinearRegression`` per voxel.
    - ``args.estimator == 'MT'`` fits a single ``MultiTaskLasso(alpha=0.1)``
      over all voxels.

    Their predictions form the context set of a Neural Process (``Encoder``,
    ``Decoder``, ``NP``), which is then trained in four stages with learning
    rates 1e-2, 1e-3, 1e-4 and 1e-5.

    Predictions (``yhat``) and predictive variances (``ys2``) for the test
    covariates are written to ``args.outdir``. When ``args.testrespfile`` is
    given, these are written as well:
    - deviation maps (``Z``);
    - Rho, pRho, RMSE, SMSE, EXPV and MSLL maps;
    - extreme-value statistics (stored in the model pickle).

    Args:
        args: namespace-like object. Attributes read here: ``respfile``,
            ``covfile``, ``testcovfile``, ``testrespfile``, ``mask``, ``m``,
            ``estimator``, ``epochs``, ``batchnum``, ``device``, ``outdir``.
            As side effects ``args.type`` is set to ``'MT'`` and
            ``args.cnn_feature_num`` is set from the constructed encoder
            (presumably read by ``Decoder``/``NP`` -- TODO confirm).

    Raises:
        ValueError: if ``args.estimator`` is neither ``'ST'`` nor ``'MT'``.
    """
    torch.set_default_dtype(torch.float32)
    args.type = 'MT'
    if args.estimator not in ('ST', 'MT'):
        # Any other value used to leave the fixed-effect context as all zeros
        # without any warning; fail before doing expensive work instead.
        raise ValueError("args.estimator must be 'ST' or 'MT', got %r"
                         % (args.estimator,))

    print('Loading the input Data ...')
    # transpose([3, 0, 1, 2]) moves the last volume axis to the front;
    # presumably (x, y, z, subject) -> (subject, x, y, z) -- TODO confirm.
    responses = fileio.load_nifti(args.respfile,
                                  vol=True).transpose([3, 0, 1, 2])
    response_shape = responses.shape
    with open(args.covfile, 'rb') as handle:
        covariates = pickle.load(handle)['covariates']
    with open(args.testcovfile, 'rb') as handle:
        test_covariates = pickle.load(handle)['test_covariates']
    if args.mask is not None:
        mask = fileio.load_nifti(args.mask, vol=True)
        mask = fileio.create_mask(mask, mask=None)
    else:
        # No explicit mask given: derive one from the first response volume.
        mask = fileio.create_mask(responses[0, :, :, :], mask=None)
    if args.testrespfile is not None:
        test_responses = fileio.load_nifti(args.testrespfile,
                                           vol=True).transpose([3, 0, 1, 2])
        test_responses_shape = test_responses.shape

    print('Normalizing the input Data ...')
    covariates_scaler = StandardScaler()
    covariates = covariates_scaler.fit_transform(covariates)
    test_covariates = covariates_scaler.transform(test_covariates)
    response_scaler = MinMaxScaler()
    responses = unravel_2D(response_scaler.fit_transform(ravel_2D(responses)),
                           response_shape)
    if args.testrespfile is not None:
        test_responses = unravel_2D(
            response_scaler.transform(ravel_2D(test_responses)),
            test_responses_shape)
        # Add a singleton axis at position 1 to line up with model outputs.
        test_responses = np.expand_dims(test_responses, axis=1)

    factor = args.m
    vol_shape = list(responses.shape[1:4])
    n_voxels = int(np.prod(responses.shape[1:]))

    x_context = np.zeros([covariates.shape[0], factor, covariates.shape[1]],
                         dtype=np.float32)
    y_context = np.zeros([responses.shape[0], factor] + vol_shape,
                         dtype=np.float32)
    x_context_test = np.zeros(
        [test_covariates.shape[0], factor, test_covariates.shape[1]],
        dtype=np.float32)
    y_context_test = np.zeros([test_covariates.shape[0], factor] + vol_shape,
                              dtype=np.float32)

    print('Estimating the fixed-effects ...')
    for i in range(factor):
        x_context[:, i, :] = covariates[:, :]
        x_context_test[:, i, :] = test_covariates[:, :]
        # Bootstrap resample (with replacement) of the training rows.
        idx = np.random.randint(0, covariates.shape[0], covariates.shape[0])
        if args.estimator == 'ST':
            # One independent linear model per voxel.
            for ix in range(responses.shape[1]):
                for iy in range(responses.shape[2]):
                    for iz in range(responses.shape[3]):
                        reg = LinearRegression()
                        reg.fit(x_context[idx, i, :],
                                responses[idx, ix, iy, iz])
                        y_context[:, i, ix, iy, iz] = reg.predict(
                            x_context[:, i, :])
                        y_context_test[:, i, ix, iy, iz] = reg.predict(
                            x_context_test[:, i, :])
        else:  # 'MT': a single multi-task model over all voxels at once.
            reg = MultiTaskLasso(alpha=0.1)
            reg.fit(
                x_context[idx, i, :],
                np.reshape(responses[idx, :, :, :],
                           [covariates.shape[0], n_voxels]))
            y_context[:, i, :, :, :] = np.reshape(
                reg.predict(x_context[:, i, :]),
                [x_context.shape[0]] + vol_shape)
            y_context_test[:, i, :, :, :] = np.reshape(
                reg.predict(x_context_test[:, i, :]),
                [x_context_test.shape[0]] + vol_shape)
        print('Fixed-effect %d of %d is computed!' % (i + 1, factor))

    x_all = x_context
    # Repeat the targets once per fixed-effect sample along axis 1.
    responses = np.expand_dims(responses, axis=1).repeat(factor, axis=1)

    ################################## TRAINING #################################

    encoder = Encoder(x_context, y_context, args).to(args.device)
    args.cnn_feature_num = encoder.cnn_feature_num
    decoder = Decoder(x_context, y_context, args).to(args.device)
    model = NP(encoder, decoder, args).to(args.device)

    print('Estimating the Random-effect ...')
    # Four stages: 1/4, 1/2 and 1/5 of the epochs, then the exact remainder,
    # so the stages always sum to args.epochs. The old float-based remainder
    # could drop epochs (e.g. 10 epochs -> 2 + 5 + 2 + 0 = 9).
    stage_epochs = [int(args.epochs / 4), int(args.epochs / 2),
                    int(args.epochs / 5)]
    stage_epochs.append(int(args.epochs) - sum(stage_epochs))
    mini_batch_num = args.batchnum
    # Rows beyond mini_batch_num * batch_size are skipped in each epoch;
    # the per-epoch permutation varies which rows are skipped.
    batch_size = int(x_context.shape[0] / mini_batch_num)
    model.train()
    epoch_num = 1
    for stage, n_epochs in enumerate(stage_epochs):
        # Learning rate decays tenfold per stage: 1e-2, 1e-3, 1e-4, 1e-5.
        optimizer = optim.Adam(model.parameters(), lr=10**(-stage - 2))
        for _ in range(n_epochs):
            train_loss = 0
            rand_idx = np.random.permutation(x_context.shape[0])
            for b in range(mini_batch_num):
                optimizer.zero_grad()
                idx = rand_idx[b * batch_size:(b + 1) * batch_size]
                # Built once and used both as model input and loss target.
                y_batch = torch.tensor(responses[idx, :, :, :, :],
                                       device=args.device)
                y_hat, z_all, z_context, _ = model(
                    torch.tensor(x_context[idx, :, :], device=args.device),
                    torch.tensor(y_context[idx, :, :, :, :],
                                 device=args.device),
                    torch.tensor(x_all[idx, :, :], device=args.device),
                    y_batch)
                loss = np_loss(y_hat, y_batch, z_all, z_context)
                loss.backward()
                train_loss += loss.item()
                optimizer.step()
            print('Epoch: %d, Loss:%f, Average Loss:%f' %
                  (epoch_num, train_loss, train_loss / responses.shape[0]))
            epoch_num += 1

    ################################## Evaluation #################################

    print('Predicting on Test Data ...')
    model.eval()
    # Re-enable dropout at test time (apply_dropout_test), presumably for
    # Monte-Carlo sampling over the n=15 forward passes -- TODO confirm.
    model.apply(apply_dropout_test)
    with torch.no_grad():
        y_hat, z_all, z_context, y_sigma = model(
            torch.tensor(x_context_test, device=args.device),
            torch.tensor(y_context_test, device=args.device),
            n=15)

    # Only produced when test responses are available; stored as None
    # otherwise, so the final model pickle never hits an undefined name.
    EVD_params = None
    abnormal_probs = None
    if args.testrespfile is not None:
        n_test = test_responses_shape[0]
        test_loss = np_loss(y_hat[0:n_test, :],
                            torch.tensor(test_responses, device=args.device),
                            z_all, z_context).item()
        print('Average Test Loss:%f' % (test_loss / n_test))

        y_pred = y_hat[0:n_test, :].cpu().numpy()
        y_std = y_sigma[0:n_test, :].cpu().numpy()

        RMSE = np.sqrt(np.mean((test_responses - y_pred)**2,
                               axis=0)).squeeze() * mask
        SMSE = RMSE**2 / np.var(test_responses, axis=0).squeeze()
        Rho, pRho = compute_pearsonr(test_responses.squeeze(),
                                     y_pred.squeeze())
        EXPV = explained_var(test_responses.squeeze(),
                             y_pred.squeeze()) * mask
        # NOTE(review): MSLL is normally referenced to the *training* mean and
        # variance; test statistics are used here -- confirm this is intended.
        MSLL = compute_MSLL(
            test_responses.squeeze(),
            y_pred.squeeze(),
            y_std.squeeze()**2,
            train_mean=test_responses.mean(0),
            train_var=test_responses.var(0)).squeeze() * mask

        # Normative probability maps: residuals in units of predictive std.
        NPMs = (test_responses - y_pred) / y_std
        NPMs = NPMs.squeeze()
        NPMs = NPMs * mask
        NPMs = np.nan_to_num(NPMs)

        temp = NPMs.reshape(
            [NPMs.shape[0], NPMs.shape[1] * NPMs.shape[2] * NPMs.shape[3]])
        EVD_params = extreme_value_prob_fit(temp, 0.01)
        abnormal_probs = extreme_value_prob(EVD_params, temp, 0.01)

    ############################## SAVING RESULTS #################################

    print('Saving Results to: %s' % (args.outdir))
    exfile = args.respfile
    y_hat = y_hat.squeeze().cpu().numpy()
    y_hat = response_scaler.inverse_transform(ravel_2D(y_hat))
    y_hat = y_hat[:, mask.flatten()]
    fileio.save(y_hat.T,
                args.outdir + '/yhat.nii.gz',
                example=exfile,
                mask=mask)
    # Undo the MinMax scaling of the std (multiply by data range), then square
    # to obtain variances in the original response units.
    ys2 = y_sigma.squeeze().cpu().numpy()
    ys2 = ravel_2D(ys2) * (response_scaler.data_max_ -
                           response_scaler.data_min_)
    ys2 = ys2**2
    ys2 = ys2[:, mask.flatten()]
    fileio.save(ys2.T, args.outdir + '/ys2.nii.gz', example=exfile, mask=mask)
    if args.testrespfile is not None:
        NPMs = ravel_2D(NPMs)[:, mask.flatten()]
        fileio.save(NPMs.T,
                    args.outdir + '/Z.nii.gz',
                    example=exfile,
                    mask=mask)
        fileio.save(Rho.flatten()[mask.flatten()],
                    args.outdir + '/Rho.nii.gz',
                    example=exfile,
                    mask=mask)
        fileio.save(pRho.flatten()[mask.flatten()],
                    args.outdir + '/pRho.nii.gz',
                    example=exfile,
                    mask=mask)
        fileio.save(RMSE.flatten()[mask.flatten()],
                    args.outdir + '/rmse.nii.gz',
                    example=exfile,
                    mask=mask)
        fileio.save(SMSE.flatten()[mask.flatten()],
                    args.outdir + '/smse.nii.gz',
                    example=exfile,
                    mask=mask)
        fileio.save(EXPV.flatten()[mask.flatten()],
                    args.outdir + '/expv.nii.gz',
                    example=exfile,
                    mask=mask)
        fileio.save(MSLL.flatten()[mask.flatten()],
                    args.outdir + '/msll.nii.gz',
                    example=exfile,
                    mask=mask)

    # '/' separator added for consistency with the other outputs; previously
    # the model was written to '<outdir>model.pkl' outside the directory.
    with open(args.outdir + '/model.pkl', 'wb') as handle:
        pickle.dump(
            {
                'model': model,
                'covariates_scaler': covariates_scaler,
                'response_scaler': response_scaler,
                'EVD_params': EVD_params,
                'abnormal_probs': abnormal_probs
            },
            handle,
            protocol=configs.PICKLE_PROTOCOL)

    ###############################################################################
    print('DONE!')