def Validation(model, optimizer, lossfunction, targetvar, train_graphs,
               val_graphs, train_queue, val_queue, val_queue_fill,
               train_queue_fill, workers, batch_size, val_batches, epoch,
               epochs, device=None):
    """Run one validation pass and hand the worker pool back to training.

    Reads ``val_batches`` batches from ``val_queue`` (fed by the processes in
    ``val_queue_fill``), evaluates ``model`` on them in eval mode without
    gradient tracking and sums the per-batch losses.  After each batch,
    ``Processing.Process_check`` is consulted (presumably it reports whether
    the validation workers have finished); the first time it is set, training
    workers for the next epoch are spawned, unless this is the last epoch.
    If it is still unset after the final batch, the validation workers are
    terminated and the next epoch's training workers are spawned then.

    Args:
        model: network; ``model(batch)[:, 0]`` is compared with
            ``batch.y[:, targetvar]``.  Its train/eval mode is restored on
            return.
        optimizer: unused; kept for call compatibility.
        lossfunction: ``lossfunction(prediction, target)``; results are summed.
        targetvar: column index into ``batch.y`` used as the target.
        train_graphs, train_queue, workers, batch_size: forwarded to
            ``Processing.Spawn_Processes`` to start the training workers.
        val_graphs: unused; kept for call compatibility.
        val_queue: queue the validation batches are read from.
        val_queue_fill: validation worker processes feeding ``val_queue``.
        train_queue_fill: current training workers; returned unchanged when
            no new ones are spawned.
        val_batches: number of batches to evaluate; must be >= 1.
        epoch, epochs: current epoch index and total epoch count; no training
            workers are spawned after the last epoch (``epochs - 1``).
        device: device batches are moved to.  Defaults to the device of the
            model's first parameter (CPU if it has none); earlier versions
            read an undeclared module-level ``device`` instead.

    Returns:
        (val_loss, train_queue_fill): the summed loss as a 1-element float64
        tensor, and the (possibly newly spawned) training workers.

    Raises:
        ValueError: if ``val_batches`` < 1.
    """
    if val_batches < 1:
        raise ValueError('val_batches must be >= 1, got %s' % (val_batches,))
    if device is None:
        first_param = next(model.parameters(), None)
        device = (first_param.device if first_param is not None
                  else torch.device('cpu'))

    more_epochs = epoch < epochs - 1
    deadcheck = False
    traincheck = False
    val_loss = torch.tensor([0], dtype=float).to(device)

    # Dropout / BatchNorm must not behave as in training while validating;
    # remember the caller's mode so it can be restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for k in range(val_batches):
                batch = Processing.GrabBatch(val_queue, device)
                output = model(batch)[:, 0]
                val_loss += lossfunction(output, batch.y[:, targetvar])
                deadcheck = Processing.Process_check(
                    val_queue_fill, deadcheck, k, val_batches)
                if deadcheck and not traincheck and more_epochs:
                    # Start producing next epoch's training data early.
                    train_queue_fill = Processing.Spawn_Processes(
                        workers, train_graphs, train_queue, batch_size)
                    traincheck = True
    finally:
        model.train(was_training)

    deadcheck = Processing.Process_check(val_queue_fill, deadcheck, k,
                                         val_batches)
    if not deadcheck:
        print('Validation Processes Still Alive. Terminating...')
        for process in val_queue_fill:
            process.terminate()
        if more_epochs:
            train_queue_fill = Processing.Spawn_Processes(
                workers, train_graphs, train_queue, batch_size)
    return val_loss, train_queue_fill
def Trainloop(model, optimizer, lossfunction, targetvar, device, train_graphs,
              val_graphs, workers, batch_size, train_batches, val_batches,
              epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Batches are produced by background worker processes (via ``Processing``)
    into a managed ``train_queue`` / ``val_queue``.  During training, as soon
    as ``Processing.Process_check`` sets its flag for the training workers
    (presumably: they have finished), the validation workers are spawned so
    validation data is ready when the epoch ends.  If the flag is still unset
    after the last batch, the training workers are terminated and the
    validation workers are spawned then.  ``Validation`` returns the training
    workers that the next epoch consumes.

    Args:
        model: network called as ``model(batch)``; put in train mode each epoch.
        optimizer: stepped once per training batch.
        lossfunction: ``lossfunction(output, batch.y[:, targetvar])``; the
            result is back-propagated.
        targetvar: column index into ``batch.y`` used as the target.
        device: passed to ``Processing.GrabBatch``; also holds the loss sum.
        train_graphs, val_graphs: data handed to the worker processes.
        workers: forwarded to ``Processing.Spawn_Processes``.
        batch_size: forwarded to the workers; also divides the reported losses.
        train_batches, val_batches: batches consumed per epoch; both >= 1.
        epochs: number of epochs; >= 1.

    Returns:
        (model, average_train_loss_per_epoch, average_val_loss_per_epoch):
        each list entry is an epoch's summed batch loss divided by
        ``n_batches * batch_size``.

    Raises:
        ValueError: if train_batches, val_batches or epochs is < 1.
        TypeError: if the model output contains NaN (type kept for callers).
    """
    if min(train_batches, val_batches, epochs) < 1:
        raise ValueError(
            'train_batches, val_batches and epochs must all be >= 1, got '
            '%s, %s, %s' % (train_batches, val_batches, epochs))

    average_train_loss_per_epoch = []
    average_val_loss_per_epoch = []

    # One manager and one pair of queues for the whole run; workers are
    # re-spawned into the same queues every epoch.
    manager = torch.multiprocessing.Manager()
    train_queue = manager.Queue()
    val_queue = manager.Queue()
    train_queue_fill = Processing.Spawn_Processes(
        workers, train_graphs, train_queue, batch_size)

    for epoch in tqdm(range(epochs)):
        deadcheck = False
        valcheck = False
        model.train()
        train_loss = torch.tensor([0], dtype=float).to(device)

        for k in range(train_batches):
            with torch.enable_grad():
                batch = Processing.GrabBatch(train_queue, device)
                optimizer.zero_grad()
                output = model(batch)
                # Abort before backward/step so NaNs never reach the weights.
                if torch.isnan(output).any():
                    raise TypeError('NAN ENCOUNTERED AT : %s / %s'
                                    % (k, train_batches))
                loss = lossfunction(output, batch.y[:, targetvar])
                loss.backward()
                optimizer.step()
            # detach(): summing the live loss would chain every batch's
            # autograd history onto train_loss for the whole epoch.
            train_loss += loss.detach()

            deadcheck = Processing.Process_check(
                train_queue_fill, deadcheck, k, train_batches)
            if deadcheck and not valcheck:
                val_queue_fill = Processing.Spawn_Processes(
                    workers, val_graphs, val_queue, batch_size)
                valcheck = True

        if not deadcheck:
            print('Training Processes Still Alive. Terminating...')
            for process in train_queue_fill:
                process.terminate()
            val_queue_fill = Processing.Spawn_Processes(
                workers, val_graphs, val_queue, batch_size)

        with torch.no_grad():
            val_loss, train_queue_fill = Validation(
                model, optimizer, lossfunction, targetvar, train_graphs,
                val_graphs, train_queue, val_queue, val_queue_fill,
                train_queue_fill, workers, batch_size, val_batches, epoch,
                epochs)

        mean_train_loss = train_loss.item() / (train_batches * batch_size)
        average_train_loss_per_epoch.append(mean_train_loss)
        average_val_loss_per_epoch.append(
            val_loss.item() / (val_batches * batch_size))
        print(mean_train_loss)

    deadcheck = Processing.Process_check(train_queue_fill, deadcheck, k,
                                         train_batches)
    if not deadcheck:
        print('Training done. Terminating slaves...')
        for process in train_queue_fill:
            process.terminate()
    return model, average_train_loss_per_epoch, average_val_loss_per_epoch
def Predict(model, prediction_graphs, workers, pred_mini_batches, batch_size,
            currfolder, device, targetvar):
    """Run ``model`` over worker-produced batches and write ``predictions.csv``.

    Spawns workers (``Processing.Spawn_Processes``) that fill a managed queue
    with batches built from ``prediction_graphs``, evaluates ``model`` on
    ``pred_mini_batches`` of them in eval mode without gradients, and writes
    one CSV row per element with columns
    ``event_no, Pid, Antineutrino, Neutrino``.

    Args:
        model: network; ``model(data)`` output is written as the last two
            columns (assumes exactly two output columns -- confirm).
        prediction_graphs: data handed to the worker processes.
        workers, batch_size: forwarded to ``Processing.Spawn_Processes``.
        pred_mini_batches: number of batches to predict on.
        currfolder: directory the CSV is written into (str or path-like).
        device: device passed to ``Processing.GrabBatch``.
        targetvar: column of ``data.y`` passed through ``Pidclass`` to build
            the ``Pid`` column.

    Workers that are still alive -- or all of them, if prediction raised --
    are terminated, and the queue manager is shut down before the CSV is built.
    """
    import os

    predictions = []
    truths = []
    pred_events = []
    manager = torch.multiprocessing.Manager()
    q = manager.Queue()
    slaves = Processing.Spawn_Processes(workers, prediction_graphs, q,
                                        batch_size)
    dead_check = False
    completed = False
    model.eval()
    try:
        with torch.no_grad():
            for mini_batch in range(pred_mini_batches):
                data = Processing.GrabBatch(q, device)
                prediction = model(data)
                truth = (Pidclass(data.y[:, targetvar]).unsqueeze(1)
                         .detach().cpu().numpy())
                pred_events.extend(data.event_no.detach().cpu().numpy())
                predictions.extend(prediction.detach().cpu().numpy())
                truths.extend(truth)
                dead_check = Processing.Process_check(
                    slaves, dead_check, mini_batch, pred_mini_batches)
        completed = True
    finally:
        # Also runs when the loop raised, so workers never outlive the call.
        if not (completed and dead_check):
            for slave in slaves:
                slave.terminate()
        manager.shutdown()

    print('Saving results...')
    result = pd.concat([pd.DataFrame(pred_events),
                        pd.DataFrame(truths),
                        pd.DataFrame(predictions)], axis=1)
    result.columns = ['event_no', 'Pid', 'Antineutrino', 'Neutrino']
    # os.path.join: plain string concatenation breaks when currfolder has no
    # trailing separator and raises for pathlib.Path inputs.
    result.to_csv(os.path.join(currfolder, 'predictions.csv'), index=False)
# Script-level training loop.  Relies on names defined outside this view
# (epochs, workers, train_graphs, val_graphs, train_queue, val_queue,
# batch_size, model, optimizer, lossfunction, device, minibatches,
# valbatches, lossarr) -- verify they exist before this runs.
for epoch in tqdm(range(epochs)):
    train_queue_fill = Processing.Spawn_Processes(workers, train_graphs,
                                                  train_queue, batch_size)
    deadcheck = False
    valcheck = False
    model.train()
    losstemp = np.zeros(valbatches)
    for k in range(minibatches):
        batch = Processing.GrabBatch(train_queue, device)
        optimizer.zero_grad()
        output = model(batch)
        loss = lossfunction(output, batch.y)
        loss.backward()
        optimizer.step()
        deadcheck = Processing.Process_check(train_queue_fill, deadcheck, k,
                                             minibatches)
        if deadcheck and not valcheck:
            val_queue_fill = Processing.Spawn_Processes(
                workers, val_graphs, val_queue, batch_size)
            valcheck = True
    if not valcheck:
        # Training workers never reported finished, so no validation workers
        # were started; without this the loop below would presumably wait on
        # an empty val_queue.  Mirrors the fallback used in Trainloop.
        for process in train_queue_fill:
            process.terminate()
        val_queue_fill = Processing.Spawn_Processes(workers, val_graphs,
                                                    val_queue, batch_size)
    model.eval()
    # Validation only: no autograd graph is needed.
    with torch.no_grad():
        for k in range(valbatches):
            batch = Processing.GrabBatch(val_queue, device)
            output = model(batch)
            loss = lossfunction(output, batch.y)
            losstemp[k] = loss.item()
    lossarr[epoch] = np.mean(losstemp)
print("done training")
# while count<workers:
#     batch=GrabBatch(val_queue)
#     if batch=="done":