Example #1
import torch

import Processing  # project-local module with the queue-based loading helpers


def Validation(model, optimizer, lossfunction, targetvar, train_graphs,
               val_graphs, train_queue, val_queue, val_queue_fill,
               train_queue_fill, workers, batch_size, val_batches, epoch,
               epochs, device):
    deadcheck = False   # becomes True once all validation workers have exited
    traincheck = False  # becomes True once next-epoch training workers start
    val_loss = torch.tensor([0], dtype=float).to(device)
    for k in range(val_batches):
        with torch.no_grad():
            batch = Processing.GrabBatch(val_queue, device)
            output = model(batch)[:, 0]
            loss = lossfunction(output, batch.y[:, targetvar])
        val_loss += loss
        deadcheck = Processing.Process_check(val_queue_fill, deadcheck, k,
                                             val_batches)
        # Once the validation workers are done, start pre-filling the
        # training queue for the next epoch while validation finishes.
        if deadcheck and not traincheck and epoch < epochs - 1:
            train_queue_fill = Processing.Spawn_Processes(
                workers, train_graphs, train_queue, batch_size)
            traincheck = True
    deadcheck = Processing.Process_check(val_queue_fill, deadcheck, k,
                                         val_batches)
    if not deadcheck:
        print('Validation processes still alive. Terminating...')
        for process in val_queue_fill:
            process.terminate()
        if epoch < epochs - 1:
            train_queue_fill = Processing.Spawn_Processes(
                workers, train_graphs, train_queue, batch_size)

    return val_loss, train_queue_fill
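
The `Processing` module these examples depend on is not shown. Inferred purely from the call sites (`Spawn_Processes` returns a list of producer processes that fill a queue, `GrabBatch` pops one batch from the queue and moves it to the device, `Process_check` reports whether all producers have exited), a minimal sketch could look like the following; the sharding and collation details are assumptions, not the original implementation:

import torch.multiprocessing as mp
from torch_geometric.data import Batch


def _fill_queue(graphs, queue, batch_size):
    # Hypothetical producer: collate fixed-size slices of this worker's shard
    # into PyTorch Geometric Batch objects and push them onto the shared queue.
    for start in range(0, len(graphs), batch_size):
        queue.put(Batch.from_data_list(graphs[start:start + batch_size]))


def Spawn_Processes(workers, graphs, queue, batch_size):
    # Start `workers` producer processes; round-robin sharding is an assumption.
    processes = []
    for i in range(workers):
        p = mp.Process(target=_fill_queue,
                       args=(graphs[i::workers], queue, batch_size))
        p.start()
        processes.append(p)
    return processes


def GrabBatch(queue, device):
    # Block until a batch is available, then move it to the target device.
    return queue.get().to(device)


def Process_check(processes, deadcheck, k, n_batches):
    # Sticky liveness check: once all producers have exited, keep reporting
    # True. (`k` and `n_batches` are accepted to match the call sites but are
    # unused in this sketch.)
    if deadcheck:
        return True
    return all(not p.is_alive() for p in processes)

The point of this producer/consumer split is that batch collation runs in separate processes and overlaps with the GPU forward/backward passes, which is why the loops above can spawn the next queue's workers before the current loop finishes.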
Example #2
import torch
from tqdm import tqdm

import Processing  # project-local module with the queue-based loading helpers


def Trainloop(model, optimizer, lossfunction, targetvar, device, train_graphs,
              val_graphs, workers, batch_size, train_batches, val_batches,
              epochs):
    average_train_loss_per_epoch = list()
    average_val_loss_per_epoch = list()
    for epoch in tqdm(range(epochs)):
        deadcheck = False  # becomes True once all training workers have exited
        valcheck = False   # becomes True once validation workers are spawned
        model.train()
        train_loss = torch.tensor([0], dtype=float).to(device)
        # The queues and the first set of training workers are created once,
        # on the first epoch; later epochs reuse the handles returned by
        # Validation.
        if epoch == 0:
            manager = torch.multiprocessing.Manager()
            train_queue = manager.Queue()
            val_queue = manager.Queue()
            train_queue_fill = Processing.Spawn_Processes(
                workers, train_graphs, train_queue, batch_size)
        for k in range(train_batches):
            with torch.enable_grad():
                batch = Processing.GrabBatch(train_queue, device)
                optimizer.zero_grad()
                output = model(batch)
                loss = lossfunction(output, batch.y[:, targetvar])
                loss.backward()
                optimizer.step()

            deadcheck = Processing.Process_check(train_queue_fill, deadcheck,
                                                 k, train_batches)
            # Detach so the running total does not keep the whole graph alive.
            train_loss += loss.detach()

            # As soon as the training workers are done, start filling the
            # validation queue so validation can begin without waiting.
            if deadcheck and not valcheck:
                val_queue_fill = Processing.Spawn_Processes(
                    workers, val_graphs, val_queue, batch_size)
                valcheck = True
            if torch.sum(torch.isnan(output)) != 0:
                raise TypeError('NaN encountered at batch %s / %s'
                                % (k, train_batches))
        # If the training workers somehow outlived the loop, terminate them
        # and spawn the validation workers the loop never got to start.
        if not deadcheck:
            print('Training processes still alive. Terminating...')
            for process in train_queue_fill:
                process.terminate()
            val_queue_fill = Processing.Spawn_Processes(
                workers, val_graphs, val_queue, batch_size)
            valcheck = True
        with torch.no_grad():
            val_loss, train_queue_fill = Validation(
                model, optimizer, lossfunction, targetvar, train_graphs,
                val_graphs, train_queue, val_queue, val_queue_fill,
                train_queue_fill, workers, batch_size, val_batches, epoch,
                epochs, device)
        average_train_loss_per_epoch.append(
            train_loss.item() / (train_batches * batch_size))
        average_val_loss_per_epoch.append(
            val_loss.item() / (val_batches * batch_size))
        print(train_loss.item() / (train_batches * batch_size))
    deadcheck = Processing.Process_check(train_queue_fill, deadcheck, k,
                                         train_batches)
    if not deadcheck:
        print('Training done. Terminating workers...')
        for process in train_queue_fill:
            process.terminate()
    del batch, loss, train_loss, val_loss
    return model, average_train_loss_per_epoch, average_val_loss_per_epoch
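
For context, a hypothetical invocation of Trainloop; the model class, graph lists, and hyperparameters below are placeholders and not taken from the original code:

# Hypothetical usage sketch; MyGNN, train_graphs, and val_graphs are
# placeholders for the project's own model and data.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyGNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lossfunction = torch.nn.MSELoss()

batch_size = 1024
model, train_curve, val_curve = Trainloop(
    model, optimizer, lossfunction, targetvar=0, device=device,
    train_graphs=train_graphs, val_graphs=val_graphs, workers=4,
    batch_size=batch_size,
    train_batches=len(train_graphs) // batch_size,
    val_batches=len(val_graphs) // batch_size,
    epochs=30)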
Example #3
import pandas as pd
import torch

import Processing  # project-local module with the queue-based loading helpers


def Predict(model, prediction_graphs, workers, pred_mini_batches, batch_size,
            currfolder, device, targetvar):
    predictions = []
    truths = []
    pred_events = []
    manager = torch.multiprocessing.Manager()
    q = manager.Queue()
    producers = Processing.Spawn_Processes(workers, prediction_graphs, q,
                                           batch_size)
    dead_check = False
    model.eval()
    with torch.no_grad():
        for mini_batch in range(pred_mini_batches):
            data = Processing.GrabBatch(q, device)
            prediction = model(data)
            # Pidclass is a project-local helper (not shown) mapping the PID
            # label to the binary neutrino/antineutrino target.
            truth = (Pidclass(data.y[:, targetvar])
                     .unsqueeze(1).detach().cpu().numpy())
            pred_events.extend(data.event_no.detach().cpu().numpy())
            predictions.extend(prediction.detach().cpu().numpy())
            truths.extend(truth)
            dead_check = Processing.Process_check(producers, dead_check,
                                                  mini_batch,
                                                  pred_mini_batches)
        if not dead_check:
            for producer in producers:
                producer.terminate()
        print('Saving results...')
        truths = pd.DataFrame(truths)
        predictions = pd.DataFrame(predictions)
        pred_events = pd.DataFrame(pred_events)
        result = pd.concat([pred_events, truths, predictions], axis=1)
        result.columns = ['event_no', 'Pid', 'Antineutrino', 'Neutrino']
        result.to_csv(currfolder + 'predictions.csv', index=False)
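
A hypothetical call matching the signature above; the graph list and output folder are placeholders:

# Hypothetical usage sketch; test_graphs and the folder are placeholders.
Predict(model, prediction_graphs=test_graphs, workers=4,
        pred_mini_batches=len(test_graphs) // 1024, batch_size=1024,
        currfolder='./results/', device=device, targetvar=0)
results = pd.read_csv('./results/predictions.csv')
print(results.head())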
Example #4
# Fragment of an earlier version of the training loop; lossarr, minibatches,
# valbatches, the queues, and the imports (numpy as np, tqdm, torch,
# Processing) are defined outside the excerpt.
for epoch in tqdm(range(epochs)):
    train_queue_fill = Processing.Spawn_Processes(workers, train_graphs,
                                                  train_queue, batch_size)
    deadcheck = False
    valcheck = False
    model.train()
    losstemp = np.zeros(valbatches)
    for k in range(minibatches):
        batch = Processing.GrabBatch(train_queue, device)
        optimizer.zero_grad()
        output = model(batch)
        loss = lossfunction(output, batch.y)
        loss.backward()
        optimizer.step()

        deadcheck = Processing.Process_check(train_queue_fill, deadcheck, k,
                                             minibatches)
        # Spawn validation workers as soon as the training workers finish.
        if deadcheck and not valcheck:
            val_queue_fill = Processing.Spawn_Processes(
                workers, val_graphs, val_queue, batch_size)
            valcheck = True
    model.eval()
    with torch.no_grad():
        for k in range(valbatches):
            batch = Processing.GrabBatch(val_queue, device)
            output = model(batch)
            loss = lossfunction(output, batch.y)
            losstemp[k] = loss.item()
    lossarr[epoch] = np.mean(losstemp)
    print('done training')