def check_fitted(name_dir, user_choices):
    """
    Check object types on previously fitted sample.

    input: name_dir, str
           directory where fitted samples are stored.

           user_choices, dict
           output from snclass.util.read_user_input

    output: surv_names, dict
            keys are types and values are names of raw data files
    """
    # list objects surviving selection cuts
    surv = os.listdir(name_dir)

    surv_names = {}
    for name in surv:
        user_choices['path_to_lc'] = [translate_snid(name)[0]]

        try_lc = read_snana_lc(user_choices)
        stype = try_lc['SIM_NON1a:'][0]

        if stype not in surv_names:
            surv_names[stype] = user_choices['path_to_lc']
        else:
            surv_names[stype].append(user_choices['path_to_lc'][0])

    return surv_names
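
# A minimal usage sketch for check_fitted, with hypothetical paths and config
# file name; it assumes read_user_input accepts the path to a snclass
# configuration file.
from snclass.util import read_user_input

user_choices = read_user_input('user.input')              # hypothetical config
fitted_types = check_fitted('fitted_spec_samples/', user_choices)
for stype in fitted_types:
    print(stype + ': ' + str(len(fitted_types[stype])) + ' surviving objects')
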
def check_mean_GP(key, surv, draw, user_choices, synthetic_dir):
    """
    Check if there is a GP fit file.

    input: key, str
           object type

           surv, dict
           dictionary of objects surviving basic cuts
           keys are types, values are lists of raw data file names

           draw, dict
           dictionary of number of extra draws necessary for each type
           keys are types, values are number of draws

           user_choices, dict
           output from snclass.util.read_user_input

           synthetic_dir, str
           directory where synthetic sample is stored

    output: ready, list
            list of already existing files
    """
    ready = []
    
    # run through all types
    if key in surv.keys():
        for obj in surv[key]:

            # get object id
            obj_id = translate_snid(obj)

            # run through all draws required for this type and
            # check whether they already exist
            for j in xrange(draw[key]):
                mean_file = synthetic_dir + '/' + \
                            user_choices['file_root'][0] + str(j) + \
                            'X' + obj_id + '_mean.dat'
                if os.path.isfile(mean_file) and mean_file not in ready:
                    ready.append(mean_file)
                    screen('Found ready SN ' + str(len(ready)) + 'X' + obj_id, user_choices)
    else:
        screen('type ' + str(key) + ' not present in surviving sample.',
               user_choices)
    
    return ready
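
# Usage sketch for check_mean_GP (hypothetical paths, type code and numbers):
# count how many synthetic GP mean files already exist for type '0' before
# drawing extra realizations.
from snclass.util import read_user_input

user_choices = read_user_input('user.input')              # hypothetical config
surv = check_fitted('fitted_spec_samples/', user_choices)
draw = {'0': 50}                                          # extra draws per type
ready = check_mean_GP('0', surv, draw, user_choices, 'synthetic_spec')
print('Already available: ' + str(len(ready)) + ' mean files')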
Example #3
def classify_1obj(din):
    """
    Perform classification of 1 supernova.

    input: din, dict - keywords, value type: 
                     user_input, dict -> output from read_user_input
                     name, str -> name of raw light curve file
                     type_number, dict -> translate between str and numerical
                                          classes identification
                     do_plot, bool -> if True produce plots, default is False

                     p1, dict ->  keywords, value type:
                         fname_photo_list, str: list of all photometric 
                                                sample objects
                         photo_dir, str: directory of GP fitted results
                                         for photo sample
                         range_pcs, list: [min_number_PCs, max_number_PCs]
                                          to be tested through cross-validation
                         SNR_dir, str: directory to store all results from 
                                       this SNR cut
                         out_dir, str: directory to store classification 
                                       results
                         plot_proj_dir, str: directory to store 
                                             projection plots
                         data_matrix, str: file holding spec data matrix

    output: class_results, list
            [snid, true_type, prob_Ia]
    """
    from snclass.functions import screen, nneighbor
    from snclass.util import translate_snid, read_snana_lc
    from snclass.treat_lc import LC

    # update supernova name
    din['user_input']['path_to_lc'] = [translate_snid(din['name'])[0]]

    # read raw data
    raw = read_snana_lc(din['user_input'])

    # set true type
    for names in din['type_number'].keys():
        type_flag = din['user_input']['type_flag'][0]
        if raw[type_flag][0] in din['type_number'][names]:
            true_type = names

    # load GP fit and test epoch cuts
    new_lc = LC(raw, din['user_input'])
    new_lc.user_choices['samples_dir'] = [din['p1']['photo_dir']]
    new_lc.load_fit_GP(din['p1']['photo_dir'] + din['name'])

    l1 = [
        1 if len(new_lc.fitted['GP_fit'][fil]) > 0 else 0
        for fil in din['user_input']['filters']
    ]

    fil_choice = din['user_input']['ref_filter'][0]
    if fil_choice == 'None':
        fil_choice = None

    if sum(l1) == len(din['user_input']['filters']):
        new_lc.normalize(samples=True, ref_filter=fil_choice)
        new_lc.mjd_shift()
        new_lc.check_epoch()

        if new_lc.epoch_cuts:

            screen(new_lc.raw['SNID:'][0], din['user_input'])

            # build matrix lines
            new_lc.build_steps(samples=True)

            # transform samples
            small_matrix = new_lc.samples_for_matrix
            data_test = din['p1']['obj_kpca'].transform(small_matrix)

            #classify samples
            new_label = nneighbor(data_test, din['p1']['spec_matrix'],
                                  din['p1']['binary_types'], din['user_input'])

            # calculate final probability
            ntypes = [1 for item in new_label if item == '0']
            new_lc.prob_Ia = sum(ntypes) / \
                             float(din['user_input']['n_samples'][0])

            if din['do_plot']:
                plot_proj(din['p1']['spec_matrix'], data_test,
                          din['p1']['labels'], new_lc,
                          din['p1']['plot_proj_dir'], [0, 1], true_type)

            # print result to screen
            screen('SN' + new_lc.raw['SNID:'][0] + \
                   ',   True type: ' + true_type + ', prob_Ia = ' + \
                    str(new_lc.prob_Ia), din['user_input'])

            class_results = [new_lc.raw['SNID:'][0], true_type, new_lc.prob_Ia]
            return class_results
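
# Shape of the dictionary expected by classify_1obj (hypothetical values). The
# 'p1' entry must already carry the trained kPCA objects ('obj_kpca',
# 'spec_matrix', 'binary_types', 'labels') that classify(), further below,
# builds before dispatching each object.
din_template = {
    'user_input': {},                    # output of snclass.util.read_user_input
    'name': 'DES_SN000001_mean.dat',     # hypothetical GP mean file name
    'type_number': {'Ia': ['0'], 'nonIa': ['1', '2', '3']},
    'do_plot': False,                    # skip projection plots
    'p1': {},                            # classification setup, see classify()
}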
Example #4
def set_lclist(params):
    """
    Build a list of all objects satisfying selection cuts and plot them.

    input: params, dict
           keywords: plot_dir
                     path to store plots; if None plots are not generated

                     fitted_data_dir
                     path to fitted data

                     list_dir
                     path to list directory 

                     sample
                     'spec' or 'photo'

                     user_choices, dict
                     output from snclass.util.read_user_input
    """
    import numpy as np
    import pylab as plt
    import os

    from snclass.treat_lc import LC
    from snclass.util import translate_snid, read_snana_lc
    from snclass.functions import screen
    import sys

    # create plot directory
    if params['plot_dir'] is not None and \
    not os.path.isdir(params['plot_dir']):
        os.makedirs(params['plot_dir'])

    flist = os.listdir(params['fitted_data_dir'])

    photo_list = []
    problem = []
    cont = 0

    rfil = params['user_choices']['ref_filter'][0]

    for obj in flist:

        if 'mean' in obj and '~' not in obj and 'Y' not in obj:

            screen(obj, params['user_choices'])

            rname = translate_snid(obj)[0]
            params['user_choices']['path_to_lc'] = [rname]
            params['user_choices']['n_samples'] = ['0']

            raw = read_snana_lc(params['user_choices'])
            new_lc = LC(raw, params['user_choices'])

            if (params['user_choices']['file_root'][0] + raw['SNID:'][0] + \
               '_samples.dat' in flist):
                new_lc.user_choices['n_samples'] = ['100']
                new_lc.user_choices['samples_dir'] = [
                    params['fitted_data_dir']
                ]

                try:
                    new_lc.load_fit_GP(params['fitted_data_dir'] + obj)
                    l1 = [
                        1 if len(new_lc.fitted['GP_fit'][fil]) > 0 else 0
                        for fil in params['user_choices']['filters']
                    ]

                    if sum(l1) == len(params['user_choices']['filters']):
                        if rfil == 'None':
                            new_lc.normalize()
                        else:
                            new_lc.normalize(ref_filter=rfil)
                        new_lc.mjd_shift()
                        new_lc.check_epoch()

                        if new_lc.epoch_cuts:
                            photo_list.append(rname)

                            # only plot if not already done
                            if params['plot_dir'] is not None and \
                            not os.path.isfile(params['plot_dir'] + 'SN' + \
                                               raw['SNID:'][0] + '.png'):
                                new_lc.plot_fitted(file_out=\
                                                   params['plot_dir'] + \
                                                   'SN' + raw['SNID:'][0] + \
                                                   '.png')
                        else:
                            screen('SN' + raw['SNID:'][0] + ' did not satisfy' + \
                                   ' epoch cuts!\n', params['user_choices'])
                            cont = cont + 1
                    else:
                        screen('SN' + raw['SNID:'][0] + ' does not exist in ' + \
                               'all filters!\n', params['user_choices'])
                        cont = cont + 1

                except ValueError:
                    problem.append(rname)
                    cont = cont + 1

            else:
                screen('Samples not found for SN' + raw['SNID:'][0],
                       params['user_choices'])

        else:
            cont = cont + 1

    screen('Missed ' + str(cont) + ' SN.', params['user_choices'])

    # store list of problematic fits
    if len(problem) > 0:
        op2 = open('problematic_fits.dat', 'w')
        for obj in problem:
            op2.write(obj + '\n')
        op2.close()
        sys.exit()

    # set parameter for file name
    if int(params['user_choices']['epoch_cut'][0]) < 0:
        epoch_min = str(abs(int(params['user_choices']['epoch_cut'][0])))
    else:
        epoch_min = 'p' + \
                    str(abs(int(params['user_choices']['epoch_cut'][0])))

    epoch_max = str(int(params['user_choices']['epoch_cut'][1]) - 1)

    filter_list = params['user_choices']['filters'][0]
    for item in params['user_choices']['filters'][1:]:
        filter_list = filter_list + item

    # save objs list
    if not os.path.isdir(params['list_dir']):
        os.makedirs(params['list_dir'])

    ref_filter = params['user_choices']['ref_filter'][0]
    if ref_filter == 'None':
        ref_fils = 'global'
    else:
        ref_fils = ref_filter

    op1 = open(params['list_dir'] + params['sample'] + '_' + filter_list + \
               '_' + epoch_min + '_' + epoch_max + '_ref_' + ref_fils + \
               '.list', 'w')
    for item in photo_list:
        op1.write(item + '\n')
    op1.close()
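
# Usage sketch for set_lclist (hypothetical paths and config file): list every
# photometric object whose GP fit satisfies the selection cuts and write the
# resulting .list file; plots are skipped by passing plot_dir=None.
from snclass.util import read_user_input

choices = read_user_input('user.input')                   # hypothetical config
set_lclist({'plot_dir': None,
            'fitted_data_dir': 'fitted_photo_samples/',
            'list_dir': 'lists/',
            'sample': 'photo',
            'user_choices': choices})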
Example #5
def select_GP(params, user_choices):
    """
    Select original objs to build a synthetic spectroscopic sample.

    input: params, dict
           output from set_parameters

           user_choices, dict
           output from snclass.util.read_user_input
    """
    from snclass.util import translate_snid, read_snana_lc
    from snclass.functions import screen
    from snclass.treat_lc import LC
    from snclass.fit_lc_gptools import save_result

    import os
    import numpy as np
    import sys

    # set reference filter
    if user_choices['ref_filter'][0] == 'None':
        fil_choice = None
    else:
        fil_choice = user_choices['ref_filter'][0]

    # select extra GP realizations in order to construct
    # a representative spec sample
    for key in params['draw_spec_samples'].keys():
        cont = 0
        fail = 0

        # check if there are existing objs in this sample
        screen('... Check existing objs', user_choices)
        ready = []
        for obj in params['surv_spec_names'][key]:
            obj_id = translate_snid(obj)

            for j in xrange(params['draw_spec_samples'][key]):
                mean_file = params['synthetic_dir'] + '/' + \
                            user_choices['file_root'][0] + str(j) + \
                            'X' + obj_id + '_mean.dat'

                if os.path.isfile(mean_file) and mean_file not in ready:
                    cont = cont + 1
                    ready.append(mean_file)
                    screen('Found ready SN ' + str(cont) + 'X' + \
                           obj_id, user_choices)

        while cont < params['draw_spec_samples'][key]:

            # draw one of the objs in the spec sample
            indx = np.random.randint(0, params['spec_pop'][key])
            name = params['surv_spec_names'][key][indx]

            user_choices['path_to_lc'] = [name]

            # read light curve raw data
            raw = read_snana_lc(user_choices)

            if os.path.isfile(params['fitted_data_dir'] + user_choices['file_root'][0] + \
                              raw['SNID:'][0] + '_samples.dat'):

                # initiate light curve object
                my_lc = LC(raw, user_choices)

                screen('Loading SN' + raw['SNID:'][0], user_choices)

                # load GP fit
                my_lc.user_choices['n_samples'] = ['100']
                my_lc.user_choices['samples_dir'] = [params['fitted_data_dir']]
                my_lc.load_fit_GP(params['fitted_data_dir'] + user_choices['file_root'][0] + \
                                  raw['SNID:'][0] + '_mean.dat')

                l1 = [
                    1 if len(my_lc.fitted['GP_fit'][fil]) > 0 else 0
                    for fil in user_choices['filters']
                ]
                if sum(l1) == len(user_choices['filters']):

                    # normalize
                    my_lc.normalize(samples=True, ref_filter=fil_choice)

                    # shift to peak mjd
                    my_lc.mjd_shift()

                    # check epoch requirements
                    my_lc.check_epoch()

                    if my_lc.epoch_cuts:

                        screen('... Passed epoch cuts', user_choices)
                        screen('... ... This is SN type ' +  raw[user_choices['type_flag'][0]][0] + \
                               ' number ' + str(cont + 1) + ' of ' +
                               str(params['draw_spec_samples'][key]), user_choices)

                        # draw one realization
                        size = len(my_lc.fitted['realizations'][
                            user_choices['filters'][0]])
                        indx2 = np.random.randint(0, size)

                        for fil in user_choices['filters']:
                            print '... ... ... filter ' + fil

                            raw['GP_fit'][fil] = my_lc.fitted['realizations'][
                                fil][indx2]
                            raw['GP_std'][fil] = my_lc.fitted['GP_std'][fil]
                            raw['xarr'][fil] = my_lc.fitted['xarr'][fil]

                        # set new file root
                        raw['file_root'] = [user_choices['file_root'][0] + \
                                             str(cont) + 'X']
                        raw['samples_dir'] = [params['synthetic_dir'] + '/']
                        save_result(raw)

                        # check epoch for this realization
                        new_lc = LC(raw, user_choices)
                        new_lc.load_fit_GP(params['synthetic_dir'] + '/' + \
                                       user_choices['file_root'][0] + str(cont) + \
                                       'X' + raw['SNID:'][0] + '_mean.dat')
                        new_lc.normalize(ref_filter=fil_choice)
                        new_lc.mjd_shift()
                        new_lc.check_epoch()

                        if new_lc.epoch_cuts:
                            cont = cont + 1
                        else:
                            screen('Samples failed to pass epoch cuts!\n',
                                   user_choices)
                            os.remove(params['synthetic_dir'] + '/' +
                                      user_choices['file_root'][0] + str(cont) + \
                                  'X' + raw['SNID:'][0] + '_mean.dat')
                        print '\n'

                    else:
                        screen('Failed to pass epoch cuts!\n', user_choices)
                        fail = fail + 1

                    if fail > 10 * params['spec_pop'][key]:
                        cont = 100000
                        sys.exit()
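
# Shape of the `params` dictionary expected by select_GP (hypothetical values):
# in practice it is the output of set_parameters below, extended with the
# per-type raw file names returned by check_fitted.
params_template = {
    'draw_spec_samples': {'0': 50, '1': 10},       # extra draws needed per type
    'surv_spec_names': {'0': [], '1': []},         # raw file names per type
    'spec_pop': {'0': 120, '1': 40, 'tot': 160},   # objects per type in spec sample
    'synthetic_dir': 'synthetic_spec',             # where new realizations go
    'fitted_data_dir': 'fitted_spec_samples/',     # GP fits of the original sample
}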
Example #6
def set_parameters(params):
    """
    Set extra sample parameters and copy raw files to new directory.

    input: params, dict
           keywords: 'sample_size' -> number of objs in synthetic sample
                     'photo_perc' -> output from photo_frac
                     'spec_pop' -> output from sample_pop in spectroscopic
                                   (training) set
                     'synthetic_dir' -> directory to store synthetic sample
                     'list_name' -> file holding list of all objs in 
                                    spectroscopic sample

                     'fitted_data_dir' -> directory with spectroscopic
                                          sample fitted with GP
                     'user_choices' -> output from
                                       snclass.util.read_user_input
                     'representation' ->
                         if 'original' the number of objects is not changed
                         if 'balanced' the final sample contains the same
                            number of objs from all types
                         if 'representative' the final spec sample resembles
                            the proportions in the photometric sample

    output: update params dict
            new keywords: 'spec_num', dict ->  number of objs expected in the
                                               synthetic sample per class
                          'draw_spec_samples', dict -> number of draws
                                                         in each class
    """
    import numpy as np
    import os
    import shutil

    from snclass.util import translate_snid

    # calculate sample size
    if params['representation'] == 'original' or params['sample_size'] is None:
        params['sample_size'] = params['spec_pop']['tot']

    #construct number of SN expected in spec sample
    params['spec_num'] = {}
    for item in params['photo_perc'].keys():
        params['spec_num'][item] = np.round(params['sample_size'] * \
                                            params['photo_perc'][item])

    #define number of SN to draw from spec sample
    params['draw_spec_samples'] = {}
    for item in params['spec_pop'].keys():
        if item != 'tot':
            diff = params['spec_num'][item] - params['spec_pop'][item]
            params['draw_spec_samples'][item] = int(diff) if diff > 0 else 0

    #construct synthetic spec data directory
    if not os.path.isdir(params['synthetic_dir']):
        os.makedirs(params['synthetic_dir'])

    # copy all mean files to synthetic directory
    fsample = read_file(params['list_name'])
    for fname in fsample:
        new_name = params['user_choices']['file_root'][0] + \
                   translate_snid(fname[0]) + '_mean.dat'
        shutil.copy2(params['fitted_data_dir'] + new_name,
                     params['synthetic_dir'] + new_name)

    return params
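
# Usage sketch for set_parameters (hypothetical numbers and paths): request a
# synthetic spectroscopic sample that follows the photometric proportions.
from snclass.util import read_user_input

user_choices = read_user_input('user.input')              # hypothetical config
pars = set_parameters({
    'sample_size': 200,                            # desired synthetic sample size
    'photo_perc': {'0': 0.25, '1': 0.75},          # hypothetical output of photo_frac
    'spec_pop': {'0': 120, '1': 40, 'tot': 160},   # hypothetical output of sample_pop
    'synthetic_dir': 'synthetic_spec/',
    'list_name': 'lists/spec.list',                # hypothetical spec sample list
    'fitted_data_dir': 'fitted_spec_samples/',
    'representation': 'representative',
    'user_choices': user_choices,
})
# pars now also holds the 'spec_num' and 'draw_spec_samples' dictionaries.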
Example #7
def build_sample(params):
    """
    Build a directory holding all raw data passing selection cuts.

    input: params, dict
           keywords:  'raw_dir' -> new directory to be created
                      'photo_dir' -> photometric LC fitted with GP
                      'spec_dir' -> spectroscopic LC fitted with GP
                      'user_choices' -> output from 
                                        snclass.util.read_user_input
    """
    import shutil
    from snclass.util import read_user_input, read_snana_lc, translate_snid
    from snclass.treat_lc import LC
    from snclass.functions import screen

    # create data directory
    if not os.path.isdir(params['raw_dir']):
        os.makedirs(params['raw_dir'])

    # read fitted light curves
    photo_list = os.listdir(params['photo_dir'])
    spec_list = os.listdir(params['spec_dir'])

    # build filter list
    fil_list = params['user_choices']['filters'][0]
    for i in xrange(1, len(params['user_choices']['filters'])):
        fil_list = fil_list + params['user_choices']['filters'][i]

    for sn_set in [photo_list, spec_list]:
        for obj in sn_set:
            if 'samples' in obj and '~' not in obj and 'Y' not in obj:

                screen(obj, params['user_choices'])

                rname = translate_snid(obj)[0]
                params['user_choices']['path_to_lc'] = [rname]
                params['user_choices']['n_samples'] = ['0']

                # read raw data
                raw = read_snana_lc(params['user_choices'])
                new_lc = LC(raw, params['user_choices'])

                # load GP fit
                if sn_set == photo_list:
                    new_lc.load_fit_GP(params['photo_dir'] +
                                       params['user_choices']['file_root'][0] +
                                       raw['SNID:'][0] + '_mean.dat')
                else:
                    new_lc.load_fit_GP(params['spec_dir'] +
                                       params['user_choices']['file_root'][0] +
                                       raw['SNID:'][0] + '_mean.dat')

                l1 = [
                    1 if len(new_lc.fitted['GP_fit'][fil]) > 0 else 0
                    for fil in params['user_choices']['filters']
                ]

                if sum(l1) == len(params['user_choices']['filters']):
                    # treat light curve
                    rfil = params['user_choices']['ref_filter'][0]
                    new_lc.normalize(ref_filter=None if rfil == 'None'
                                     else rfil)
                    new_lc.mjd_shift()
                    new_lc.check_basic()
                    new_lc.check_epoch()

                    # check epoch cuts
                    data_path = params['user_choices']['path_to_obs'][0]
                    if new_lc.epoch_cuts:
                        shutil.copy2(data_path + rname,
                                     params['raw_dir'] + rname)
                    else:
                        screen('... SN' + raw['SNID:'][0] + \
                               ' failed to pass epoch cuts!',
                               params['user_choices'])
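
# Usage sketch for build_sample (hypothetical directories): copy every raw
# light curve whose GP fit passes the basic and epoch cuts into 'raw_cut/'.
from snclass.util import read_user_input

choices = read_user_input('user.input')                   # hypothetical config
build_sample({'raw_dir': 'raw_cut/',
              'photo_dir': 'fitted_photo_samples/',
              'spec_dir': 'fitted_spec_samples/',
              'user_choices': choices})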
Example #8
def classify(p1, user_input, type_number, do_plot=False):
    """
    Classify all objects in photometric sample.

    input: p1, dict
           keywords, value type:
               fname_photo_list, str: list of all photometric sample objects
               photo_dir, str: directory of GP fitted results for photo sample
               range_pcs, list: [min_number_PCs, max_number_PCs] to be tested
                                through cross-validation
               SNR_dir, str: directory to store all results from this SNR cut
               out_dir, str: directory to store classification results
               plot_proj_dir, str: directory to store projection plots
               data_matrix, str: file holding spec data matrix

           user_input, dict
           output from snclass.util.read_user_input

           type_number, dict
           dictionary to translate types between raw data and final
           classification
           keywords -> final classification elements
           values -> identifiers in raw data

           do_plot, bool - optional
           if True, create a projection plot for all test objects
           default is False
    """
    from snclass.functions import screen
    from snclass.util import translate_snid, read_snana_lc
    from snclass.treat_lc import LC

    import os
    import sys
    import numpy as np
    from multiprocessing import Pool

    # read photometric sample
    photo_fname = read_file(p1['fname_photo_list'])
    photo_list = [user_input['file_root'][0] + translate_snid(item[0]) + \
                  '_mean.dat'
                  for item in photo_fname
                  if os.path.isfile(p1['photo_dir'] +
                                    user_input['file_root'][0] + translate_snid(item[0]) +
                                    '_samples.dat') and '~' not in item[0]]

    for npcs in xrange(p1['range_pcs'][0], p1['range_pcs'][1]):

        if int(user_input['epoch_cut'][0]) < 0:
            mjd_min = str(abs(int(user_input['epoch_cut'][0])))
        else:
            mjd_min = 'p' + str(abs(int(user_input['epoch_cut'][0])))

        if int(user_input['epoch_cut'][1]) < 0:
            mjd_max = 'm' + str(abs(int(user_input['epoch_cut'][1]) - 1))
        else:
            mjd_max = str(int(user_input['epoch_cut'][1]) - 1)

        fils = user_input['filters'][0]
        for item in user_input['filters'][1:]:
            fils = fils + item

        out_dir = p1['out_dir']
        plot_proj_dir = p1['plot_proj_dir']
        p1['out_dir'] = p1['out_dir'] + str(npcs) + 'PC/'
        p1['cv_file'] = p1['out_dir'] + 'hyperpar_values.dat'

        if plot_proj_dir is not None:
            p1['plot_proj_dir'] = p1['plot_proj_dir'] + str(npcs) + 'PC/'
            if not os.path.isdir(p1['plot_proj_dir']):
                os.makedirs(p1['plot_proj_dir'])

        if not os.path.isdir(p1['out_dir']):
            os.makedirs(p1['out_dir'])

        p1['pars'], p1['alphas'] = read_hyperpar(p1)
        p1['data'], p1['sntype'], p1['binary_types'] = \
            read_matrix(p1['data_matrix'], Ia_codes=type_number['Ia'])
        p1['obj_kpca'], p1['spec_matrix'], p1['labels'] = \
            set_kpca_obj(p1['pars'], p1['data'], p1['sntype'], type_number)

        pars = []
        for name in photo_list:
            ptemp = {}
            ptemp['p1'] = p1
            ptemp['name'] = name
            ptemp['type_number'] = type_number
            ptemp['user_input'] = user_input
            ptemp['do_plot'] = do_plot

            pars.append(ptemp)

        if int(user_input['n_proc'][0]) > 1:
            pool = Pool(processes=int(user_input['n_proc'][0]))
            my_pool = pool.map_async(classify_1obj, pars)
            try:
                results = my_pool.get(0xFFFF)
            except KeyboardInterrupt:
                print 'Interrupted by the user!'
                sys.exit()

            pool.close()
            pool.join()

        else:
            results = []
            for element in pars:
                results.append(classify_1obj(element))

        p1['out_dir'] = out_dir
        p1['plot_proj_dir'] = plot_proj_dir

        if not os.path.isdir(p1['out_dir'] + str(npcs) + 'PC/'):
            os.makedirs(p1['out_dir'] + str(npcs) + 'PC/')

        op2 = open(
            p1['out_dir'] + str(npcs) + 'PC/class_res_' + str(npcs) + 'PC.dat',
            'w')
        op2.write('SNID    true_type    prob_Ia\n')
        for line in results:
            for item in line:
                op2.write(str(item) + '    ')
            op2.write('\n')
        op2.close()
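
# Usage sketch for classify (hypothetical paths and type codes): run the full
# photometric classification for kernel PCA dimensions 2 to 5.
from snclass.util import read_user_input

user_input = read_user_input('user.input')                # hypothetical config
type_number = {'Ia': ['0'], 'nonIa': ['1', '2', '3']}     # hypothetical codes
p1 = {'fname_photo_list': 'lists/photo.list',
      'photo_dir': 'fitted_photo_samples/',
      'range_pcs': [2, 6],
      'SNR_dir': 'SNR5/',
      'out_dir': 'SNR5/class_results/',
      'plot_proj_dir': None,                              # skip projection plots
      'data_matrix': 'SNR5/spec_matrix.dat'}
classify(p1, user_input, type_number, do_plot=False)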
Example #9
def classify_test(test_name,
                  matrix,
                  user_input,
                  test_dir='test_samples/',
                  csamples=True):
    """
    Classify one photometric supernova using a trained KernelPCA matrix.

    input: test_name, str
           name of mean GP fit file

           matrix, snclass.matrix.DataMatrix object
           trained KernelPCA matrix

           user_input, dict
           output from snclass.util.read_user_input

           test_dir, str, optional
           name of directory to store samples from test light curve
           Default is 'test_samples/'

           csamples, bool, optional
           If True, fit GP object and generate sample file as output
           otherwise reads samples from file
           Default is True

    return: new_lc, snclass.treat_lc.LC object
            updated with test projections and probability of being Ia
    """
    # update path to raw light curve
    user_input['path_to_lc'] = [translate_snid(test_name, 'FLUXCAL')[0]]

    # store the number of samples for later tests
    nsamples = user_input['n_samples'][0]

    # reset the number of samples for preliminary tests
    user_input['n_samples'] = ['0']

    # read raw data
    raw = read_snana_lc(user_input)

    # load GP fit and test epoch cuts
    new_lc = LC(raw, user_input)
    new_lc.load_fit_GP(user_input['samples_dir'][0] + test_name)
    new_lc.normalize()
    new_lc.mjd_shift()
    new_lc.check_epoch()

    if new_lc.epoch_cuts:
        # update test sample directory
        user_input['samples_dir'] = [test_dir]

        # update user choices
        new_lc.user_choices = user_input

        # update number of samples
        new_lc.user_choices['n_samples'] = [nsamples]

        # fit GP or normalize/shift fitted mean
        test_matrix = test_samples(new_lc, calc_samples=bool(csamples))

        # project test
        new_lc.test_proj = matrix.transf_test.transform(test_matrix)

        # classify
        new_lc.new_label = nneighbor(new_lc.test_proj, matrix.low_dim_matrix,
                                     matrix.sntype, matrix.user_choices)

        if csamples:
            new_lc.prob_Ia = sum([1 for item in new_lc.new_label
                                  if item == '0']) / float(nsamples)

        return new_lc

    else:
        return None
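
# Usage sketch for classify_test, assuming `dmatrix` holds an already trained
# snclass.matrix.DataMatrix (built elsewhere) and the GP mean file name is
# hypothetical:
#
#     lc = classify_test('DES_SN000001_mean.dat', dmatrix, user_input,
#                        test_dir='test_samples/', csamples=True)
#     if lc is not None:
#         print('prob_Ia = ' + str(lc.prob_Ia))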
Example #10
    def check_file(self, filename, epoch=True, ref_filter=None):
        """
        Construct one line of the data matrix.

        input:   filename, str
                 file of raw data for 1 supernova

                 epoch, bool - optional
                 If true, check if SN satisfies epoch cuts
                 Default is True

                 ref_filter, str - optional
                 Reference filter for peak MJD calculation
                 Default is None
        """
        screen('Fitting ' + filename, self.user_choices)

        # translate identifier
        self.user_choices['path_to_lc'] = [
            translate_snid(filename, self.user_choices['photon_flag'][0])[0]
        ]

        # read light curve raw data
        raw = read_snana_lc(self.user_choices)

        # initiate light curve object
        lc_obj = LC(raw, self.user_choices)

        # load GP fit
        lc_obj.load_fit_GP(self.user_choices['samples_dir'][0] + filename)

        # normalize
        lc_obj.normalize(ref_filter=ref_filter)

        # shift to peak mjd
        lc_obj.mjd_shift()

        if epoch:
            # check epoch requirements
            lc_obj.check_epoch()
        else:
            lc_obj.epoch_cuts = True

        if lc_obj.epoch_cuts:
            # build data matrix lines
            lc_obj.build_steps()

            # store
            obj_line = []
            for fil in self.user_choices['filters']:
                for item in lc_obj.flux_for_matrix[fil]:
                    obj_line.append(item)

            rflag = self.user_choices['redshift_flag'][0]
            redshift = raw[rflag][0]

            obj_class = raw[self.user_choices['type_flag'][0]][0]

            self.snid.append(raw['SNID:'][0])

            return obj_line, redshift, obj_class

        else:
            screen('... Failed to pass epoch cuts!', self.user_choices)
            screen('\n', self.user_choices)
            return None
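
# Usage sketch for check_file, assuming `dmatrix` is an already configured
# snclass.matrix.DataMatrix-like instance (with user_choices set) and the file
# name and reference filter are hypothetical:
#
#     row = dmatrix.check_file('DES_SN000001_mean.dat', epoch=True,
#                              ref_filter='r')
#     if row is not None:
#         obj_line, redshift, obj_class = row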