Esempio n. 1
0
def create_outputs(x, feat, energy=None, n_ccoords=3, n_classes=6, td=None, add_features=True):
    '''
    Attach the per-hit output heads to ``x``.

    x:            input tensor; every head is a Dense layer applied to it
    feat:         raw feature tensor, passed through td.createFeatureDict
    energy:       optional tensor the energy output is multiplied with
    n_ccoords:    number of outputs of the clustering-coordinate head
    n_classes:    number of outputs of the softmax ID head
    td:           object providing createFeatureDict. When None (default) a
                  fresh TrainData_OC is created inside the call, so no
                  instance is built at definition time or shared implicitly
                  between calls
    add_features: if True, feat['recHitXY'] / feat['recHitTime'] are added to
                  the position / time outputs

    returns pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id
    '''
    if td is None:
        td = TrainData_OC()

    feat = td.createFeatureDict(feat)

    pred_beta = Dense(1, activation='sigmoid')(x)
    pred_ccoords = Dense(n_ccoords,
                         #this initialisation is much better than standard glorot
                         kernel_initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1./float(x.shape[-1])),
                         use_bias=False)(x) #bias has no effect

    # Dense(1) output passed through ScalarMultiply(10.), optionally
    # multiplied with the provided energy tensor
    pred_energy = ScalarMultiply(10.)(Dense(1)(x))
    if energy is not None:
        pred_energy = Multiply()([pred_energy, energy])

    pred_pos = Dense(2)(x)
    pred_time = ScalarMultiply(10.)(Dense(1)(x))
    if add_features:
        # add the input hit position / time to the raw network outputs
        pred_pos = Add()([feat['recHitXY'], pred_pos])
        pred_time = Add()([feat['recHitTime'], pred_time])
    pred_id = Dense(n_classes, activation="softmax")(x)

    return pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id
Esempio n. 2
0
    def _make_plot(self, counter, feat, predicted, truth):#all these are lists and also include row splits
        '''
        Write two 3D scatter plots of the rec hits as HTML files: one coloured
        by the truth assignment index and one coloured by the back-gather
        index taken from predicted[self.use_backgather_idx].

        Only the first entry of feat and truth is used. The back-gather plot
        is additionally passed to publish() when self.publish is set.
        counter is not used in this method.
        Any exception is printed and then re-raised.
        '''
        try:
            train_data = TrainData_OC()  # contains all dicts
            # row splits not needed
            feature_dict = train_data.createFeatureDict(feat[0], addxycomb=False)
            gather_idx = predicted[self.use_backgather_idx]
            truth_dict = train_data.createTruthDict(truth[0])

            columns = dict(feature_dict)
            columns.update(truth_dict)

            # make sure the index array is a column so it can be concatenated
            if len(gather_idx.shape) < 2:
                gather_idx = np.expand_dims(gather_idx, axis=1)

            columns['recHitLogEnergy'] = np.log(columns['recHitEnergy']+1)
            columns['hitBackGatherIdx'] = gather_idx

            frame = pd.DataFrame(np.concatenate(list(columns.values()), axis=1),
                                 columns=list(columns))

            def _scatter(color_column):
                # same hit positions for both plots, only the colouring differs
                figure = px.scatter_3d(frame, x="recHitX", y="recHitZ", z="recHitY",
                                       color=color_column, size="recHitLogEnergy",
                                       template='plotly_dark',
                                       color_continuous_scale=px.colors.sequential.Rainbow)
                figure.update_traces(marker={'line': {'width': 0}})
                return figure

            shuffle_truth_colors(frame)
            _scatter("truthHitAssignementIdx").write_html(
                self.outputfile + str(self.keep_counter) + "_truth.html")

            #now the cluster indices
            bgfile = self.outputfile + str(self.keep_counter) + "_backgather.html"
            shuffle_truth_colors(frame, "hitBackGatherIdx")
            _scatter("hitBackGatherIdx").write_html(bgfile)

            if self.publish is not None:
                publish(bgfile, self.publish)

        except Exception as err:
            print(err)
            raise err
Esempio n. 3
0
 def _make_plot(self, counter, feat, predicted, truth):
     '''
     Write a 3D scatter plot of the coordinates found in
     predicted[self.use_prediction_idx], coloured by the truth assignment
     index, to an HTML file and publish it when self.publish is set.

     Only the first entry of feat and truth is used; counter is not used.
     Returns early (after printing a message) if the coordinates are not 3D.
     Any exception is printed and then re-raised.
     '''
     try:
         train_data = TrainData_OC()  # contains all dicts
         truth_dict = train_data.createTruthDict(truth[0])
         feature_dict = train_data.createFeatureDict(feat[0], addxycomb=False)

         columns = dict(truth_dict)
         columns.update(feature_dict)
         columns['recHitLogEnergy'] = np.log(columns['recHitEnergy']+1)

         coords = predicted[self.use_prediction_idx]
         if coords.shape[-1] != 3:
             print("plotGravNetCoordsDuringTraining only supports 3D coordinates") #2D and >3D TBI
             return #not supported

         # one single-column slice per coordinate axis
         for axis, label in enumerate(('coord A', 'coord B', 'coord C')):
             columns[label] = coords[:, axis:axis + 1]

         frame = pd.DataFrame(np.concatenate(list(columns.values()), axis=1),
                              columns=list(columns))
         shuffle_truth_colors(frame)

         figure = px.scatter_3d(frame, x="coord A", y="coord B", z="coord C",
                                color="truthHitAssignementIdx", size="recHitLogEnergy",
                                template='plotly_dark',
                                color_continuous_scale=px.colors.sequential.Rainbow)
         figure.update_traces(marker={'line': {'width': 0}})

         ccfile = (self.outputfile + str(self.keep_counter)
                   + "_coords_" + str(self.use_prediction_idx) + ".html")
         figure.write_html(ccfile)

         if self.publish is not None:
             publish(ccfile, self.publish)

     except Exception as err:
         print(err)
         raise err
Esempio n. 4
0
# from tensorflow.keras.optimizer_v2 import Adam

from plotting_callbacks import plotEventDuringTraining
from ragged_callbacks import plotRunningPerformanceMetrics
from DeepJetCore.DJCLayers import StopGradient, ScalarMultiply, SelectFeatures, ReduceSumEntirely

from clr_callback import CyclicLR
from lossLayers import LLFullObjectCondensation, LLClusterCoordinates

from model_blocks import create_outputs

from Layers import LocalClusterReshapeFromNeighbours2, ManualCoordTransform, RaggedGlobalExchange, LocalDistanceScaling, CheckNaN, NeighbourApproxPCA, LocalClusterReshapeFromNeighbours, GraphClusterReshape, SortAndSelectNeighbours, LLLocalClusterCoordinates, DistanceWeightedMessagePassing, CollectNeighbourAverageAndMax, CreateGlobalIndices, LocalClustering, SelectFromIndices, MultiBackGather, KNN, MessagePassing
from datastructures import TrainData_OC

# module-level TrainData_OC instance; gravnet_model below uses it to unpack the model inputs
td = TrainData_OC()


def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    '''
    Opening part of a model definition: unpack the model inputs, create
    global indices and normalise the hit features.

    NOTE(review): only the first statements are visible in this excerpt.
    feature_dropout and addBackGatherInfo are not used and there is no
    return statement here - presumably the body continues elsewhere; confirm.
    '''

    # split the flat input list into features, truth quantities and row splits
    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    # keep references to the truth quantities and row splits as unpacked above
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    # only the row splits returned by RaggedConstructTensor are kept
    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    # feature processing followed by batch normalisation
    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    allfeat = []
def analyse_one_file(_features, predictions, truth_in, soft=False):
    """Run the per-window analysis on the content of one file.

    The inputs are cut into windows/segments along the row splits, each
    window is passed to ``analyse_window_cut`` and the result is merged into
    the module-level ``dataset_analysis_dict`` through
    ``append_window_dict_to_dataset_dict``.

    Args:
        _features: indexable container; entry 0 holds the hit features,
            entry 1 the row splits (first column is used) and entries
            2, 4, 6, 8, 10 the truth index, energy, position, time and pid.
        predictions: indexable container with beta, clustering coordinates,
            energy, position, time and pid at indices 0 to 5.
        truth_in: indexable container; entry 0 is the full truth array that
            is handed to ``TrainData_OC.createTruthDict``.
        soft: forwarded unchanged to ``analyse_window_cut``.

    Returns:
        None. The module-level counters ``num_visualized_segments`` and
        ``window_id`` are incremented once per processed window.

    ``beta_threshold``, ``distance_threshold`` and ``iou_threshold`` are read
    from module scope. If the row splits describe no window at all the
    function simply returns (the former trailing ``i += 1`` raised an
    UnboundLocalError in that case and had no effect otherwise).
    """
    global num_visualized_segments, num_segments_to_visualize
    global dataset_analysis_dict, window_id

    row_splits = _features[1][:, 0]

    features = _features[0]

    truth_idx = _features[2].astype(np.int32)
    truth_energy = _features[4]
    truth_position = _features[6]
    truth_time = _features[8]
    truth_pid = _features[10]

    pred_beta = predictions[0]
    pred_ccoords = predictions[1]
    pred_energy = predictions[2]
    pred_position = predictions[3]
    pred_time = predictions[4]
    pred_pid = predictions[5]

    _, row_splits = ragged_constructor((_features[10], row_splits))

    truth_all = truth_in[0]

    # one instance is enough for all windows; it only provides the dicts
    td = TrainData_OC()

    for i in range(len(row_splits) - 1):
        # all per-window quantities share the same [start:stop] range
        window = slice(row_splits[i], row_splits[i + 1])

        analysis_input = {
            "feat_all": td.createFeatureDict(features[window],
                                             addxycomb=False),
            "truth_sid": truth_idx[window],
            "truth_energy": truth_energy[window],
            "truth_position": truth_position[window],
            "truth_time": truth_time[window],
            "truth_pid": truth_pid[window],
            "truth_all": td.createTruthDict(truth_all[window]),
            "pred_beta": pred_beta[window],
            "pred_ccoords": pred_ccoords[window],
            "pred_energy": pred_energy[window],
            "pred_position": pred_position[window],
            "pred_time": pred_time[window],
            "pred_pid": pred_pid[window],
        }

        # only the first num_segments_to_visualize windows are visualised
        visualize = bool(num_visualized_segments < num_segments_to_visualize)
        window_analysis_dict = analyse_window_cut(analysis_input,
                                                  beta_threshold,
                                                  distance_threshold,
                                                  iou_threshold,
                                                  window_id,
                                                  visualize,
                                                  soft=soft)

        append_window_dict_to_dataset_dict(dataset_analysis_dict,
                                           window_analysis_dict)
        num_visualized_segments += 1
        window_id += 1
Esempio n. 6
0
    def _make_plot(self, counter, feat, predicted, truth):
        '''
        Write three 3D scatter plots of the predicted cluster coordinates as
        HTML files, coloured by the truth assignment index. The marker size is
        taken from the log hit energy, from (beta+0.05)**2, and from the same
        quantity set to zero where beta is not above self.beta_threshold.
        Each file is passed to publish() when self.publish is set.

        predicted is indexed as
            [pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id]

        Only the first entry of feat and truth is used; counter is not used.
        Returns early if the cluster coordinates are not 3D.
        Any exception is printed and then re-raised.
        '''
        try:
            train_data = TrainData_OC()  # contains all dicts
            # row splits not needed
            feature_dict = train_data.createFeatureDict(feat[0], addxycomb=False)
            truth_dict = train_data.createTruthDict(truth[0])

            beta = predicted[0]

            print('>>>> plotting cluster coordinates... average beta',np.mean(beta), ' lowest beta ',
                  np.min(beta), 'highest beta', np.max(beta))

            ccoords = predicted[1]
            if ccoords.shape[-1] != 3:
                return #just for 3D ccoords

            energy = predicted[2]
            pos_x = predicted[3][:, 0:1]
            pos_y = predicted[3][:, 1:2]
            time = predicted[4]

            columns = dict(feature_dict)
            columns.update(truth_dict)

            columns['recHitLogEnergy'] = np.log(columns['recHitEnergy']+1)
            columns['predBeta'] = beta
            columns['predBeta+0.05'] = beta+0.05 #so that the others don't disappear
            for axis, label in enumerate(('predCCoordsX', 'predCCoordsY', 'predCCoordsZ')):
                columns[label] = ccoords[:, axis:axis + 1]
            columns['predEnergy'] = energy
            columns['predX'] = pos_x
            columns['predY'] = pos_y
            columns['predT'] = time
            columns['(predBeta+0.05)**2'] = columns['predBeta+0.05']**2
            columns['(thresh(predBeta)+0.05))**2'] = np.where(
                beta > self.beta_threshold, columns['(predBeta+0.05)**2'], 0.)

            frame = pd.DataFrame(np.concatenate(list(columns.values()), axis=1),
                                 columns=list(columns))

            shuffle_truth_colors(frame)

            full_hover = ['predBeta','predEnergy','truthHitAssignedEnergies',
                          'predT','truthHitAssignedT',
                          'predX', 'truthHitAssignedX',
                          'predY', 'truthHitAssignedY',
                          'truthHitAssignementIdx']
            reduced_hover = ['predBeta','predEnergy', 'predX', 'predY', 'truthHitAssignementIdx',
                             'truthHitAssignedEnergies', 'truthHitAssignedX','truthHitAssignedY']

            # (marker size column, hover columns, output file suffix)
            plot_specs = (
                ("recHitLogEnergy", full_hover, "_ccoords.html"),
                ("(predBeta+0.05)**2", full_hover, "_ccoords_betasize.html"),
                ("(thresh(predBeta)+0.05))**2", reduced_hover, "_ccoords_betathresh.html"),
            )

            for size_column, hover_columns, suffix in plot_specs:
                figure = px.scatter_3d(frame, x="predCCoordsX", y="predCCoordsY", z="predCCoordsZ",
                                       color="truthHitAssignementIdx", size=size_column,
                                       hover_data=hover_columns,
                                       template='plotly_dark',
                                       color_continuous_scale=px.colors.sequential.Rainbow)
                figure.update_traces(marker={'line': {'width': 0}})
                ccfile = self.outputfile + str(self.keep_counter) + suffix
                figure.write_html(ccfile)

                if self.publish is not None:
                    publish(ccfile, self.publish)

        except Exception as err:
            print(err)
            raise err