def surv_data_2_ref(surv_data):
    """Assemble a Reference with one objective per configured channel code.

    The reference's meta data is set to a single 'timestamp' entry holding the
    current UTC time. For every code in objectives_channel_codes an objective
    is added: 'prevalence' receives the points produced by
    prevalence_surv_data_2_d_points(surv_data['prevalence']); any other code
    triggers a warning and is registered with None as its data.

    surv_data -- mapping that must contain a 'prevalence' entry whenever
                 'prevalence' is among the configured channel codes

    Returns the populated Reference.
    """
    import time

    reference = Reference()

    # meta data is currently limited to a creation timestamp (UTC)
    created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())
    reference.set_meta({'timestamp': created_at})

    for code in objectives_channel_codes:
        if code != 'prevalence':
            # no converter for this channel yet: keep the objective, leave it without data
            warn_p("Channel " + code + " not implemeneted yet!\nSetting reference data to None.")
            points = None
        else:
            points = prevalence_surv_data_2_d_points(surv_data['prevalence'])

        reference.add_objective(code, points)

    return reference
    def get_ref(self, cluster_id):
        """Build the reference object used when plotting a cluster's fit.

        Survey data is collected for every code in objectives_channel_codes:
        'prevalence' data is looked up with c2p(cluster_id); any other channel
        is not supported yet and is stored as None after a warning.

        cluster_id -- identifier handed to c2p; it is concatenated into the
                      log message, so it has to be a string

        Returns d2f(surv_data), or None when the prevalence lookup yields
        nothing, which signals that the cluster should be skipped.
        """
        surv_data = {}
        all_ref_objs_found = True

        for channel_code in objectives_channel_codes:
            if channel_code == 'prevalence':
                prev_data = c2p(cluster_id)
                if prev_data:
                    surv_data[channel_code] = prev_data
                else:
                    msg = 'Prevalence objective reference data was not found!\n Skipping plotting cluster ' + cluster_id + ' fit!'
                    # a single-argument print() call prints the same text as the
                    # Python 2 print statement and is also valid on Python 3
                    print(msg)
                    all_ref_objs_found = False
            else:
                # a space separates the fixed text from the channel code
                msg = "Channel objective " + channel_code + " not implemented yet!\nSetting objective reference data to None for plotting."
                warn_p(msg)
                surv_data[channel_code] = None

        # one of the reference objective channels was not found; skipping cluster fit!
        if not all_ref_objs_found:
            return None

        return d2f(surv_data)
def sim_report_channels_model_format(reports_channels, sim_data):
    """Map each requested report channel to its data in model format.

    reports_channels -- iterable of channel names (strings)
    sim_data         -- simulation output handed to the channel extractors

    Only 'reinfections' is supported; its value is obtained from
    get_sim_report_reinfections(sim_data). Every other channel produces a
    warning and is mapped to None.

    Returns a dict keyed by channel name.
    """
    formatted = {}

    for channel in reports_channels:
        if channel != 'reinfections':
            # unsupported channel: keep the key so callers can see it was requested
            warn_p("Channel " + channel + " not implemeneted yet!\nSetting report channels data to None.")
            formatted[channel] = None
            continue

        formatted['reinfections'] = get_sim_report_reinfections(sim_data)

    return formatted
示例#4
0
 def fit(self):
     """Run the fit selected by self.type and return its result.

     'mmse_distance'        -> self.best_fit_mmse_distance()
     'mmse_distance_cached' -> self.best_fit_mmse_distance(cached = True)

     Any other type produces a warning and None is returned.
     """
     if self.type == 'mmse_distance':
         return self.best_fit_mmse_distance()

     if self.type == 'mmse_distance_cached':
         return self.best_fit_mmse_distance(cached = True)

     # unknown fit type: report it and give up
     warn_p("unrecognized fit function type " + self.type + "!\nReturning None.")
     return None
def prevalence_surv_data_2_d_points(prev_surv_data):
    """Convert per-round prevalence survey data into an ordered list of points.

    prev_surv_data -- mapping from survey round (anything int() accepts, e.g.
                      '0', '1', ...) to the observed prevalence; the value
                      -1000 marks a round without data

    Returns a list with one slot per entry of prev_surv_data in which slot
    int(round) holds the prevalence, or the string 'nan' for the -1000 marker.
    Slots that no round maps to keep their initial value 0. Returns None
    (after a warning) when prev_surv_data is empty or None.

    Raises IndexError when int(round) is not a valid index into that list,
    i.e. rounds are expected to be numbered 0 .. len(prev_surv_data) - 1.
    """
    if not prev_surv_data:
        warn_p('no reference data for prevalence available!\n Returning none.')
        return None

    prev_ts = len(prev_surv_data)*[0]

    # items() is available on both Python 2 and Python 3 dicts, whereas
    # iteritems() only exists on Python 2
    for rnd, prev in prev_surv_data.items():
        # if no data in reference impute to 'nan'
        prev_ts[int(rnd)] = 'nan' if prev == -1000 else prev

    return prev_ts
 def generate_gazetteer_map(self):
     """Build the sweep's gazetteer map and write it to disk as JSON.

     The base map is loaded from ref_data_dir into self.base_map. For every
     cluster in self.best_fits a map record is requested from
     self.get_cluster_map_record and filled with values of the cluster's best
     fit ('habs' temp_h / const_h, ITN_cov, MSAT_cov, fit_value, prevalence).
     When the file gazetteer_sim_mn_base_map_file_name exists in kariba_viz_dir,
     each record also receives 'RDT_mn_sim' from that file's entry for the
     cluster.

     The records are collected in self.sweep_map and dumped to
     <kariba_viz_dir>/<self.sweep_name>_<gazetteer_base_map_file_name>.
     Clusters without a map record are skipped with a warning.
     """
     gazetteer_base_map_file_path = os.path.join(ref_data_dir, gazetteer_base_map_file_name)

     with open(gazetteer_base_map_file_path ,'r') as map_f:
         self.base_map = json.load(map_f)

     # The path is the same for every cluster and the file content is read once,
     # on first use, instead of being re-opened and re-parsed for every cluster.
     mn_sim_map_file_path = os.path.join(kariba_viz_dir,  gazetteer_sim_mn_base_map_file_name)
     mn_map = None
     mn_map_loaded = False

     self.sweep_map = []
     # items() works on both Python 2 and Python 3 dicts (iteritems() is Python 2 only)
     for cluster_id, cluster_record in self.best_fits.items():
         cluster_map_record = self.get_cluster_map_record(cluster_id)

         if not cluster_map_record:
             warn_p('No cluster map record found in base map. Skipping generating sim entries in cluster map record.')
             continue

         cluster_map_record['temp_h'] = cluster_record['habs']['temp_h']
         cluster_map_record['const_h'] = cluster_record['habs']['const_h']

         if not mn_map_loaded and os.path.exists(mn_sim_map_file_path):
             with open(mn_sim_map_file_path ,'r') as mn_map_f:
                 mn_map = json.load(mn_map_f)
             mn_map_loaded = True

         if mn_map_loaded:
             cluster_map_record['RDT_mn_sim'] = mn_map[cluster_id]['RDT_sim']

         cluster_map_record['itn_level'] = cluster_record['ITN_cov']
         cluster_map_record['drug_coverage'] = cluster_record['MSAT_cov']
         cluster_map_record['fit_value'] = cluster_record['fit_value']
         cluster_map_record['RDT_sn_sim'] = cluster_record['prevalence']

         self.sweep_map.append(cluster_map_record)

     gazetteer_sweep_map_file_path = os.path.join(kariba_viz_dir, self.sweep_name + '_' + gazetteer_base_map_file_name)

     with open(gazetteer_sweep_map_file_path ,'w') as map_f:
         json.dump(self.sweep_map, map_f, indent = 4)
示例#7
0
 def mse(self, m_points, d_points, points_weights):
     """Return the weighted mean squared error between model and reference points.

     m_points       -- model values, one per observation
     d_points       -- reference values; must have the same length as m_points
     points_weights -- per-observation factors applied to the squared errors

     Pairs in which either value equals the string 'nan' are ignored. The sum
     of the weighted squared errors is divided by the number of pairs actually
     used (not by the sum of the weights).

     Returns None, after emitting a warning, when d_points is empty or None,
     when the lengths differ, or when every pair contains a 'nan'.
     """
     if not d_points:
         warn_p("no reference data provided for fit!\nReturning None")
         return None

     num_obs = len(d_points)

     if len(m_points) != num_obs:
         warn_p("number of points in model does not match num of points in reference data!\nReturning None.")
         return None

     if num_obs == 0:
         warn_p("no observations provided!\nReturning None.")
         return None

     weighted_sq_err = 0.0
     used_obs = 0
     for idx, (m_p, d_p) in enumerate(zip(m_points, d_points)):
         if d_p == 'nan' or m_p == 'nan':
             # missing value on either side: the pair does not contribute
             continue
         weighted_sq_err = weighted_sq_err + math.pow(m_p - d_p, 2)*points_weights[idx]
         used_obs = used_obs + 1

     if used_obs == 0:
         warn_p("only nan observations provided!\nReturning None")
         return None

     return weighted_sq_err / used_obs
    def fit(self):
        """Fit every cluster of this category and collect the results.

        For each cluster id returned by c2c(self.category):
          * the calibration models are deep-copied and narrowed down to those
            whose 'group_key' equals get_sim_group_key(itn_traj, drug_cov) of
            the cluster; each match is wrapped in a KaribaModel;
          * reference survey data is gathered per objective channel; the
            cluster is skipped entirely when prevalence data is missing;
          * the reference prevalence points are capped at rdt_max;
          * a Fit over the cluster's FittingSet provides the best-fit model
            and the min / max residuals.

        Returns a tuple (best_fits, all_fits):
          best_fits -- dict keyed by cluster id holding the best-fit parameters
                       plus 'fit', 'mse' and 'cc_penalty' sub-entries (each with
                       value, temp_h, const_h, ITN_cov, MSAT_cov, sim_id and
                       sim_key) and, when both are truthy, 'rho' and 'p_val';
                       clusters without a best fit are absent
          all_fits  -- dict with 'min_residual' / 'max_residual' over all
                       fitted clusters and 'models', mapping cluster id to the
                       list of models considered for it
        """

        def _criterion_entry(value, model_params, model_meta):
            # summary of the model that is optimal for one criterion
            # (overall fit, mse or clinical cases penalty)
            temp_h, const_h, itn_level, drug_coverage_level = model_params
            return {
                'value': value,
                'temp_h': temp_h,
                'const_h': const_h,
                'ITN_cov': itn_level,
                'MSAT_cov': drug_coverage_level,
                'sim_id': model_meta['sim_id'],
                'sim_key': model_meta['sim_key'],
            }

        models_list_prime = calib_data_2_models_list(self.calib_data)

        best_fits = {}
        all_fits = {
            'min_residual': float('inf'),
            'max_residual': 0.0,
            'models': {},
        }

        debug_p('category ' + self.category)

        for idx, cluster_id in enumerate(c2c(self.category)):

            # every cluster works on its own copy of the calibration models
            models_list = copy.deepcopy(models_list_prime)

            # single-argument print() calls print the same text as the Python 2
            # print statement and are also valid on Python 3
            print("Processing cluster " + cluster_id + ".")
            debug_p('Processing cluster ' + cluster_id + " in " + self.category + ".")

            itn_traj = cluster_2_itn_traj(cluster_id)
            drug_cov = cluster_2_drug_cov(cluster_id)

            # the key depends on the cluster only, so it is computed once
            # rather than once per model
            cluster_group_key = get_sim_group_key(itn_traj, drug_cov)

            # prune models to the ones matching prior data
            cluster_models = []
            for model in models_list:
                model_meta = model.get_meta()
                if model_meta['group_key'] != cluster_group_key:
                    continue

                group_key = model_meta['group_key']
                sim_key = model_meta['sim_key']
                cluster_models.append(
                    KaribaModel(model, self.calib_data[group_key][sim_key], cluster_id, all_fits = self.fit_terms)
                )

            surv_data = {}
            all_ref_objs_found = True
            for channel_code in objectives_channel_codes:
                if channel_code == 'prevalence':
                    prev_data = c2p(cluster_id)
                    if prev_data:
                        surv_data[channel_code] = prev_data
                    else:
                        msg = 'Prevalence objective reference data was not found!\n Skipping cluster ' + cluster_id + ' fit!'
                        print(msg)
                        all_ref_objs_found = False
                else:
                    # a space separates the fixed text from the channel code
                    msg = "Channel objective " + channel_code + " not implemented yet!\nSetting objective reference data to None."
                    warn_p(msg)
                    surv_data[channel_code] = None

            # one of the reference objective channels was not found; skipping cluster fit!
            if not all_ref_objs_found:
                continue

            ref = d2f(surv_data)

            # adjust highest possible fit to account for RDT+ model in dtk not reflecting reality at the upper end
            obj_prev = ref.get_obj_by_name('prevalence')
            d_points = obj_prev.get_points()
            # The string 'nan' is left untouched by the cap: min('nan', rdt_max)
            # returns rdt_max on Python 2 (numbers order before strings) and
            # raises TypeError on Python 3, so the marker would be lost.
            # NOTE(review): presumably d2f yields 'nan' for rounds without data — confirm.
            obj_prev.set_points([point if point == 'nan' else min(point, rdt_max) for point in d_points])

            fitting_set = FittingSet(cluster_id, cluster_models, ref)

            if load_prevalence_mse:
                cluster_fit = Fit(fitting_set, type = 'mmse_distance_cached')
            else:
                cluster_fit = Fit(fitting_set)

            # NOTE(review): best_fit_mmse_distance() is called without arguments for
            # both fit types; confirm whether the cached variant was meant to be
            # selected here when load_prevalence_mse is set.
            best_fit_model = cluster_fit.best_fit_mmse_distance()

            min_residual = cluster_fit.get_min_residual()
            max_residual = cluster_fit.get_max_residual()

            if min_residual < all_fits['min_residual']:
                all_fits['min_residual'] = min_residual

            if max_residual > all_fits['max_residual']:
                all_fits['max_residual'] = max_residual

            if best_fit_model:

                best_fit_meta = best_fit_model.get_meta()
                fit_entry = _criterion_entry(
                    best_fit_model.get_fit_val(),
                    get_model_params(best_fit_model),
                    best_fit_meta,
                )

                cluster_entry = {
                    'habs': {
                        'const_h': fit_entry['const_h'],
                        'temp_h': fit_entry['temp_h'],
                    },
                    'ITN_cov': fit_entry['ITN_cov'],
                    'category': self.category,
                    'MSAT_cov': fit_entry['MSAT_cov'],
                    'sim_id': best_fit_meta['sim_id'],
                    'sim_key': best_fit_meta['sim_key'],
                    'group_key': best_fit_meta['group_key'],
                    'fit_value': fit_entry['value'],
                    'sim_avg_reinfection_rate': best_fit_model.get_sim_avg_reinfection_rate(),
                    'ref_avg_reinfection_rate': best_fit_model.get_ref_avg_reinfection_rate(),
                    'prevalence': best_fit_model.get_objective_by_name('prevalence').get_points(),
                    # redundancy; to be refactored via FitEntry class
                    'fit': fit_entry,
                }

                # get mmse for objective prevalence
                min_mse = cluster_fit.get_min_mses()['prevalence']
                min_mse_model = min_mse['model']
                cluster_entry['mse'] = _criterion_entry(
                    min_mse['value'],
                    get_model_params(min_mse_model),
                    min_mse_model.get_meta(),
                )

                # get clinical penalty for objective prevalence; at present this is just the
                # clinical cases penalty; if reinfection is considered the code needs to be adjusted
                min_penalty = cluster_fit.get_min_penalties()['prevalence']
                min_penalty_model = min_penalty['model']
                cluster_entry['cc_penalty'] = _criterion_entry(
                    min_penalty['value'],
                    get_model_params(min_penalty_model),
                    min_penalty_model.get_meta(),
                )

                rho = best_fit_model.get_rho()
                p_val = best_fit_model.get_p_val()

                # NOTE(review): the truthiness test also drops a rho or p_val that is exactly 0
                if rho and p_val:
                    cluster_entry['rho'] = rho
                    cluster_entry['p_val'] = p_val

                    debug_p('rho' + str(rho))
                    debug_p('p_val' + str(p_val))

                best_fits[cluster_id] = cluster_entry

            else:
                msg = "something went wrong and the best fit for " + cluster_id + " could not be found."
                warn_p(msg)

            all_fits['models'][cluster_id] = cluster_models

            print(str(idx+1) + " clusters have been processed.")
            debug_p( str(idx+1) + " clusters have been processed in category " + self.category)

        return best_fits, all_fits