def make_KKfiles_Script_supercomp(hybdatadict, SDparams,prb, detectioncrit, KKparams,supercomparams):
    '''Creates the script file required to run KlustaKwik on a supercomputer'''
    argSD = [hybdatadict,SDparams,prb]
    if ju.is_cached(rsd.run_spikedetekt,*argSD):
        print 'Yes, SD has been run \n'
        hash_hyb_SD = rsd.run_spikedetekt(hybdatadict,SDparams,prb)
    else:
        raise ValueError('You need to run SpikeDetekt before attempting to analyse results')
    
    
    argTD = [hybdatadict, SDparams,prb, detectioncrit]      
    if ju.is_cached(ds.test_detection_algorithm,*argTD):
        print 'Yes, you have run detection_statistics.test_detection_algorithm() \n'
        detcrit_groundtruth = ds.test_detection_algorithm(hybdatadict, SDparams,prb, detectioncrit)
    else:
        raise ValueError('You need to run detection_statistics.test_detection_algorithm() in order to obtain a groundtruth')
    
    KKhash = hash_utils.hash_dictionary_md5(KKparams)
    baselist = [hash_hyb_SD, detcrit_groundtruth['detection_hashname'], KKhash]
    basefilename =  hash_utils.make_concatenated_filename(baselist)
    
    
    DIRPATH = hybdatadict['output_path']
    os.chdir(DIRPATH)
    
    KKscriptname = basefilename
    make_KKscript_supercomp(KKparams,basefilename,KKscriptname,supercomparams)
    
    return basefilename
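# Usage sketch, illustrative only: the supercomparams keys shown here are
# assumptions (the real ones are whatever make_KKscript_supercomp consumes),
# and the submission command depends on the cluster's scheduler.
def _example_submit_supercomp(hybdatadict, SDparams, prb, detectioncrit, KKparams):
    '''Illustrative only: generate the supercomputer KK script and submit it.'''
    supercomparams = {'walltime': '24:00:00', 'memory': '8gb'}  # hypothetical keys
    basefilename = make_KKfiles_Script_supercomp(hybdatadict, SDparams, prb,
                                                 detectioncrit, KKparams, supercomparams)
    # The script is written under this name in hybdatadict['output_path'].
    os.system('qsub %s' % basefilename)  # hypothetical scheduler invocation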
def analysis_ind_confKK(hybdatadict, SDparams,prb, detectioncrit, defaultKKparams, paramtochange, listparamvalues, detcrit = None):
    ''' Analyse the results of a one-parameter family of KK jobs.
        Much like the function above, but uses detcrit-independent KK basefilenames'''
    outlistKK = rkk.one_param_varyKK_ind(hybdatadict, SDparams,prb, defaultKKparams, paramtochange, listparamvalues) 
    #outlistKK = [listbasefiles, outputdicts]
    
    argTD = [hybdatadict, SDparams,prb, detectioncrit]      
    if ju.is_cached(ds.test_detection_algorithm,*argTD):
        print 'Yes, you have run detection_statistics.test_detection_algorithm() \n'
        detcrit_groundtruth = ds.test_detection_algorithm(hybdatadict, SDparams,prb, detectioncrit)
    else:
        raise ValueError('You need to run detection_statistics.test_detection_algorithm() in order to obtain a groundtruth')
        
    detcritclu =  detcrit_groundtruth['detected_groundtruth']
    
    NumSpikes = detcritclu.shape[0]
    
    cluKK = np.zeros((len(outlistKK[0]),NumSpikes), dtype=np.int32)
    confusion = []
    for k, basefilename in enumerate(outlistKK[0]):
        clufile = hybdatadict['output_path'] + basefilename + '.clu.1'
        if os.path.isfile(clufile):
            cluKK[k,:] = np.loadtxt(clufile, dtype=np.int32, skiprows=1)
        else:
            print '%s does not exist, skipping '%(clufile)
            continue
        
        conf = get_confusion_matrix(cluKK[k,:],detcritclu)
        print conf
        confusion.append(conf)
        
    return confusion      
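# get_confusion_matrix() is defined elsewhere; as a reading aid, a minimal
# sketch consistent with its use above (an assumption about the real
# implementation, not the actual function) is a joint histogram of the two
# labellings of the same spikes:
def _example_confusion_matrix(cluKK, detcritclu):
    '''Illustrative only: count spikes per (KK cluster, groundtruth label) pair.'''
    cluKK = np.asarray(cluKK, dtype=np.int64)
    detcritclu = np.asarray(detcritclu, dtype=np.int64)
    conf = np.zeros((cluKK.max() + 1, detcritclu.max() + 1), dtype=np.int64)
    for a, b in zip(cluKK, detcritclu):
        conf[a, b] += 1
    return conf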
def pre_learn_data_grid(hybdatadict, SDparams,prb,detectioncrit,supervised_params):
    '''First this function will query whether the cached function
       detection_statistics.test_detection_algorithm(hybdatadict, SDparams, prb, detectioncrit)
       has already been called with those arguments, using joblib_utils.is_cached.
       If it has, it calls it to obtain detcrit_groundtruth;
       otherwise it raises an error telling you to obtain a groundtruth first.
       Likewise, if SpikeDetekt has not been run on the hybrid dataset, it raises
       an error telling you to run SpikeDetekt on the dataset.
    
       It scales the data using scale_data()
    '''
    argTD = [hybdatadict, SDparams,prb, detectioncrit]      
    if ju.is_cached(ds.test_detection_algorithm,*argTD):
        print 'Yes, you have run detection_statistics.test_detection_algorithm() \n'
        detcrit_groundtruth = ds.test_detection_algorithm(hybdatadict, SDparams,prb, detectioncrit)
    else:
        raise ValueError('You need to run detection_statistics.test_detection_algorithm() in order to obtain a groundtruth')
    
    
    argSD = [hybdatadict,SDparams,prb]
    if ju.is_cached(rsd.run_spikedetekt,*argSD):
        print 'Yes, SD has been run \n'
        hash_hyb_SD = rsd.run_spikedetekt(hybdatadict,SDparams,prb)
    else:
        raise ValueError('You need to run SpikeDetekt before attempting to analyse results')
    
    DIRPATH = hybdatadict['output_path']
    with Experiment(hash_hyb_SD, dir= DIRPATH, mode='r') as expt:
        
        #Load the detcrit groundtruth
        #targetpathname = '/channel_groups/0/spikes/clusters' + '/' + detcrit_groundtruth['detection_hashname']
        targetpathname = detcrit_groundtruth['detection_hashname']
        targetsource = expt.channel_groups[0].spikes.clusters._get_child(targetpathname)
        
        #take the first supervised_params['numfirstspikes'] spikes only
        if supervised_params['numfirstspikes'] is not None: 
            fets = expt.channel_groups[0].spikes.features[0:supervised_params['numfirstspikes']]
            target = targetsource[0:supervised_params['numfirstspikes']]
        else: 
            fets = expt.channel_groups[0].spikes.features[:]
            target = targetsource[:]
    print 'fets.shape = ', fets.shape    
    print 'target.shape = ', target.shape
        
    if supervised_params['subvector'] is not None:
        subsetfets = fets[:,supervised_params['subvector']]
    else:
        subsetfets = fets
        
    scaled_fets = scale_data(subsetfets)
    classweights = compute_grid_weights(*supervised_params['grid_params'])
    
    
    return hash_hyb_SD,classweights,scaled_fets, target
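# scale_data() and compute_grid_weights() are defined elsewhere in this module;
# as a reading aid, a minimal sketch of a typical scale_data (an assumption,
# not the actual function) standardises each feature column:
def _example_scale_data(fets):
    '''Illustrative only: zero-mean, unit-variance scaling per feature column.'''
    fets = np.asarray(fets, dtype=np.float64)
    std = fets.std(axis=0)
    std[std == 0] = 1.0  # guard against constant columns
    return (fets - fets.mean(axis=0)) / std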
def create_confusion_matrix_fromclu_ind(hybdatadict, SDparams, prb, detectioncrit, KKparams):
    ''' Creates the confusion matrix from the equivalent of a clu file
        and the detcrit groundtruth res and clu files, which are now contained in the kwik file.
        The clu file will be either from KK or from the SVM, of the form:
        Hash(hybdatadict)_Hash(sdparams)_Hash(detectioncrit)_KK_Hash(kkparams).kwik or
        Hash(hybdatadict)_Hash(sdparams)_Hash(detectioncrit)_SVM_Hash(svmparams).kwik'''
    argSD = [hybdatadict,SDparams,prb]
    if ju.is_cached(rsd.run_spikedetekt,*argSD):
        print 'Yes, SD has been run \n'
        hash_hyb_SD = rsd.run_spikedetekt(hybdatadict,SDparams,prb)
    else:
        raise ValueError('You need to run SpikeDetekt before attempting to analyse results')
    
    
    argTD = [hybdatadict, SDparams,prb, detectioncrit]      
    if ju.is_cached(ds.test_detection_algorithm,*argTD):
        print 'Yes, you have run detection_statistics.test_detection_algorithm() \n'
        detcrit = ds.test_detection_algorithm(hybdatadict, SDparams,prb, detectioncrit)
    else:
        raise ValueError('You need to run detection_statistics.test_detection_algorithm() in order to obtain a groundtruth')
    
    
    
    
    basefilename = rkk.make_KKfiles_Script_detindep_full(hybdatadict, SDparams, prb, KKparams)
    
    DIRPATH = hybdatadict['output_path']
    KKclufile = DIRPATH + basefilename + '.clu.1'
    KKclusters = np.loadtxt(KKclufile, dtype=np.int32, skiprows=1)
    
    conf = get_confusion_matrix(KKclusters, detcrit['detected_groundtruth'])
    
    return detcrit, KKclusters,conf
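# Usage sketch (assumptions: conf is indexed as [KK cluster, groundtruth label],
# inferred from the argument order in the call above, and label 0 is taken to
# mean 'not a detected hybrid spike'):
def _example_best_match_fraction(hybdatadict, SDparams, prb, detectioncrit, KKparams):
    '''Illustrative only: fraction of groundtruth spikes captured by the single
       best-matching KK cluster.'''
    detcrit, KKclusters, conf = create_confusion_matrix_fromclu_ind(
        hybdatadict, SDparams, prb, detectioncrit, KKparams)
    conf = np.asarray(conf, dtype=np.float64)
    n_truth = conf[:, 1:].sum()
    return conf[:, 1:].max() / n_truth if n_truth else 0.0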
def make_KKfiles_viewer(hybdatadict, SDparams,prb, detectioncrit, KKparams):
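    '''Writes the .res/.spk/.clu files and symlinks needed to inspect the
       hybrid dataset and its detcrit groundtruth in a spike-sorting viewer.'''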
    
    argSD = [hybdatadict,SDparams,prb]
    if ju.is_cached(rsd.run_spikedetekt,*argSD):
        print 'Yes, SD has been run \n'
        hash_hyb_SD = rsd.run_spikedetekt(hybdatadict,SDparams,prb)
    else:
        raise ValueError('You need to run SpikeDetekt before attempting to analyse results')
    
    
    argTD = [hybdatadict, SDparams,prb, detectioncrit]      
    if ju.is_cached(ds.test_detection_algorithm,*argTD):
        print 'Yes, you have run detection_statistics.test_detection_algorithm() \n'
        detcrit_groundtruth = ds.test_detection_algorithm(hybdatadict, SDparams,prb, detectioncrit)
    else:
        raise ValueError('You need to run detection_statistics.test_detection_algorithm() in order to obtain a groundtruth')
        
    argKKfile = [hybdatadict, SDparams,prb, detectioncrit, KKparams]
    if ju.is_cached(make_KKfiles_Script,*argKKfile):
        print 'Yes, make_KKfiles_Script has been run \n'
    else:
        print 'make_KKfiles_Script has not been run yet, running it now '
    basefilename = make_KKfiles_Script(hybdatadict, SDparams,prb, detectioncrit, KKparams)
        
    mainbasefilelist = [hash_hyb_SD, detcrit_groundtruth['detection_hashname']]
    mainbasefilename = hash_utils.make_concatenated_filename(mainbasefilelist)    
    
    DIRPATH = hybdatadict['output_path']
    os.chdir(DIRPATH)
    with Experiment(hash_hyb_SD, dir= DIRPATH, mode='r') as expt:
        if KKparams['numspikesKK'] is not None: 
            res = expt.channel_groups[0].spikes.time_samples[0:KKparams['numspikesKK']]
        else: 
            res = expt.channel_groups[0].spikes.time_samples[:]
            
        mainresfile = DIRPATH + mainbasefilename + '.res.1' 
        mainspkfile = DIRPATH + mainbasefilename + '.spk.1'
        detcritclufilename = DIRPATH + mainbasefilename + '.detcrit.clu.1'
        trivialclufilename = DIRPATH + mainbasefilename + '.clu.1'
        write_res(res,mainresfile)
        write_trivial_clu(res,trivialclufilename)
        
        write_spk_buffered(expt.channel_groups[0].spikes.waveforms_filtered,
                            mainspkfile,
                           np.arange(len(res)))
        
        write_clu(detcrit_groundtruth['detected_groundtruth'], detcritclufilename)
            
    mainxmlfile =  hybdatadict['donor_path'] + hybdatadict['donor']+'_afterprocessing.xml'   
    
    os.system('ln -s %s %s.spk.1 ' %(mainspkfile,basefilename))
    os.system('ln -s %s %s.res.1 ' %(mainresfile,basefilename))
    os.system('cp %s %s.xml ' %(mainxmlfile,basefilename))
    
    return basefilename
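# Usage sketch, illustrative only (nothing in this module pins the exact viewer;
# Klusters-style tools read the basefilename.* family written above):
#
#   basefilename = make_KKfiles_viewer(hybdatadict, SDparams, prb, detectioncrit, KKparams)
#   os.system('klusters %s.clu.1 &' % basefilename)  # hypothetical invocation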
def make_KKfiles_Script(hybdatadict, SDparams,prb, detectioncrit, KKparams):
    '''Creates the files required to run KlustaKwik'''
    argSD = [hybdatadict,SDparams,prb]
    if ju.is_cached(rsd.run_spikedetekt,*argSD):
        print 'Yes, SD has been run \n'
        hash_hyb_SD = rsd.run_spikedetekt(hybdatadict,SDparams,prb)
    else:
        raise ValueError('You need to run SpikeDetekt before attempting to analyse results')
    
    
    argTD = [hybdatadict, SDparams,prb, detectioncrit]      
    if ju.is_cached(ds.test_detection_algorithm,*argTD):
        print 'Yes, you have run detection_statistics.test_detection_algorithm() \n'
        detcrit_groundtruth = ds.test_detection_algorithm(hybdatadict, SDparams,prb, detectioncrit)
    else:
        raise ValueError('You need to run detection_statistics.test_detection_algorithm() in order to obtain a groundtruth')
    
    KKhash = hash_utils.hash_dictionary_md5(KKparams)
    baselist = [hash_hyb_SD, detcrit_groundtruth['detection_hashname'], KKhash]
    basefilename =  hash_utils.make_concatenated_filename(baselist)
    
    mainbasefilelist = [hash_hyb_SD, detcrit_groundtruth['detection_hashname']]
    mainbasefilename = hash_utils.make_concatenated_filename(mainbasefilelist)
    
    DIRPATH = hybdatadict['output_path']
    os.chdir(DIRPATH)
    with Experiment(hash_hyb_SD, dir= DIRPATH, mode='r') as expt:
        if KKparams['numspikesKK'] is not None: 
            fets = expt.channel_groups[0].spikes.features[0:KKparams['numspikesKK']]
            fmasks = expt.channel_groups[0].spikes.features_masks[0:KKparams['numspikesKK'],:,1]
            
            masks = expt.channel_groups[0].spikes.masks[0:KKparams['numspikesKK']]
        else: 
            fets = expt.channel_groups[0].spikes.features[:]
            fmasks = expt.channel_groups[0].spikes.features_masks[:,:,1]
            masks = expt.channel_groups[0].spikes.masks[:]
    
    mainfetfile = DIRPATH + mainbasefilename+'.fet.1'
    mainfmaskfile = DIRPATH + mainbasefilename+'.fmask.1'
    mainmaskfile = DIRPATH + mainbasefilename+'.mask.1'
    
    if not os.path.isfile(mainfetfile):
        write_fet(fets,mainfetfile )
    else: 
        print mainfetfile, ' already exists, moving on \n '
        
    if not os.path.isfile(mainfmaskfile):
        write_mask(fmasks,mainfmaskfile,fmt='%f')
    else: 
        print mainfmaskfile, ' already exists, moving on \n '  
    
    if not os.path.isfile(mainmaskfile):
        write_mask(masks,mainmaskfile,fmt='%f')
    else: 
        print mainmaskfile, ' already exists, moving on \n '    
        
    
    
    os.system('ln -s %s %s.fet.1 ' %(mainfetfile,basefilename))
    os.system('ln -s %s %s.fmask.1 ' %(mainfmaskfile,basefilename))
    os.system('ln -s %s %s.mask.1 ' %(mainmaskfile,basefilename))
    
    KKscriptname = basefilename
    make_KKscript(KKparams,basefilename,KKscriptname)
    
    return basefilename
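# write_fet() is defined elsewhere; as a reading aid, a minimal sketch of the
# classic .fet layout it presumably follows (an assumption: first line gives the
# number of feature columns, then one whitespace-separated integer row per spike):
def _example_write_fet(fets, filepath):
    '''Illustrative only: write features in the classic .fet text format.'''
    with open(filepath, 'w') as f:
        f.write('%d\n' % fets.shape[1])
        np.savetxt(f, np.asarray(fets), fmt='%d')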
def make_KKfiles_Script_full(hybdatadict, SDparams,prb, detectioncrit, KKparams):
    '''Creates the files required to run KlustaKwik'''
    argSD = [hybdatadict,SDparams,prb]
    if ju.is_cached(rsd.run_spikedetekt,*argSD):
        print 'Yes, SD has been run \n'
        hash_hyb_SD = rsd.run_spikedetekt(hybdatadict,SDparams,prb)
    else:
        raise ValueError('You need to run SpikeDetekt before attempting to analyse results')
    
    
    argTD = [hybdatadict, SDparams,prb, detectioncrit]      
    if ju.is_cached(ds.test_detection_algorithm,*argTD):
        print 'Yes, you have run detection_statistics.test_detection_algorithm() \n'
        detcrit_groundtruth = ds.test_detection_algorithm(hybdatadict, SDparams,prb, detectioncrit)
    else:
        raise ValueError('You need to run detection_statistics.test_detection_algorithm() in order to obtain a groundtruth')
    
    KKhash = hash_utils.hash_dictionary_md5(KKparams)
    baselist = [hash_hyb_SD, detcrit_groundtruth['detection_hashname'], KKhash]
    basefilename =  hash_utils.make_concatenated_filename(baselist)
    
    mainbasefilelist = [hash_hyb_SD, detcrit_groundtruth['detection_hashname']]
    mainbasefilename = hash_utils.make_concatenated_filename(mainbasefilelist)
    
    DIRPATH = hybdatadict['output_path']
    os.chdir(DIRPATH)
    with Experiment(hash_hyb_SD, dir= DIRPATH, mode='r') as expt:
        if KKparams['numspikesKK'] is not None: 
            feats = expt.channel_groups[0].spikes.features[0:KKparams['numspikesKK']]
            prefmasks = expt.channel_groups[0].spikes.features_masks[0:KKparams['numspikesKK'],:,1]
            
            premasks = expt.channel_groups[0].spikes.masks[0:KKparams['numspikesKK']]
            res = expt.channel_groups[0].spikes.time_samples[0:KKparams['numspikesKK']]
        else: 
            feats = expt.channel_groups[0].spikes.features[:]
            prefmasks = expt.channel_groups[0].spikes.features_masks[:,:,1]
            premasks = expt.channel_groups[0].spikes.masks[:]
            res = expt.channel_groups[0].spikes.time_samples[:]    
            
        mainresfile = DIRPATH + mainbasefilename + '.res.1' 
        mainspkfile = DIRPATH + mainbasefilename + '.spk.1'
        detcritclufilename = DIRPATH + mainbasefilename + '.detcrit.clu.1'
        trivialclufilename = DIRPATH + mainbasefilename + '.clu.1'
        
        if os.path.isfile(mainspkfile):
            print '%s already exists, assuming the res, spk and clu files have been made, moving on '%(mainspkfile)
        else:
            make_spkresdetclu_files(expt,res,mainresfile, mainspkfile, detcritclufilename, trivialclufilename) 
        
        
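        # Append the spike time as one extra feature column with an all-zero
        # mask column: with a zero mask the time column is ignored by masked
        # KlustaKwik, while the .fet file keeps the classic convention of time
        # as the last feature (the stated rationale is an assumption).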
        times = np.expand_dims(res, axis =1)
        masktimezeros = np.zeros_like(times)
        fets = np.concatenate((feats, times),axis = 1)
        fmasks = np.concatenate((prefmasks, masktimezeros),axis = 1)
        masks = np.concatenate((premasks, masktimezeros),axis = 1)
    
    mainfetfile = DIRPATH + mainbasefilename+'.fet.1'
    mainfmaskfile = DIRPATH + mainbasefilename+'.fmask.1'
    mainmaskfile = DIRPATH + mainbasefilename+'.mask.1'
    
    
    if not os.path.isfile(mainfetfile):
        write_fet(fets,mainfetfile )
    else: 
        print mainfetfile, ' already exists, moving on \n '
        
    if not os.path.isfile(mainfmaskfile):
        write_mask(fmasks,mainfmaskfile,fmt='%f')
    else: 
        print mainfmaskfile, ' already exists, moving on \n '  
    
    if not os.path.isfile(mainmaskfile):
        write_mask(masks,mainmaskfile,fmt='%f')
    else: 
        print mainmaskfile, ' already exists, moving on \n '    
        
    
    mainxmlfile =  hybdatadict['donor_path'] + hybdatadict['donor']+'_afterprocessing.xml'   
    os.system('ln -s %s %s.fet.1 ' %(mainfetfile,basefilename))
    os.system('ln -s %s %s.fmask.1 ' %(mainfmaskfile,basefilename))
    os.system('ln -s %s %s.mask.1 ' %(mainmaskfile,basefilename))
    os.system('ln -s %s %s.trivial.clu.1 ' %(trivialclufilename,basefilename))
    os.system('ln -s %s %s.spk.1 ' %(mainspkfile,basefilename))
    os.system('ln -s %s %s.res.1 ' %(mainresfile,basefilename))
    os.system('cp %s %s.xml ' %(mainxmlfile,mainbasefilename))
    os.system('cp %s %s.xml ' %(mainxmlfile,basefilename))
    
    KKscriptname = basefilename
    make_KKscript(KKparams,basefilename,KKscriptname)
    
    return basefilename
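# End-to-end usage sketch, illustrative only (the parameter dictionaries are
# assumed to be built by the calling pipeline, and the exact way the generated
# KK script is launched is not shown in this module):
#
#   basefilename = make_KKfiles_Script_full(hybdatadict, SDparams, prb,
#                                           detectioncrit, KKparams)
#   # ...run the generated KlustaKwik script, then:
#   detcrit, KKclusters, conf = create_confusion_matrix_fromclu_ind(
#       hybdatadict, SDparams, prb, detectioncrit, KKparams)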