def train(args):
    datasets = [i for i in range(5)]
    # Remove the leave-out dataset from the datasets
    datasets.remove(args.leaveDataset)
    # datasets = [0]
    # args.leaveDataset = 0

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, datasets, forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = 'log_w_goal/'
    log_directory += str(args.leaveDataset) + '/'
    log_directory += 'log_attention'
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging files
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    save_directory = 'save_w_goal/'
    save_directory += str(args.leaveDataset) + '/'
    save_directory += 'save_attention'
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Dump the training configuration
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_' + str(x) + '.tar')

    # Initialize net
    net = SRNN(args)
    net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)

    learning_rate = args.learning_rate

    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(
                epoch * dataloader.num_batches + batch,
                args.num_epochs * dataloader.num_batches,
                epoch, loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches

        # Log it
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')

        # Save the model after each epoch
        print('Saving model')
        torch.save({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path(epoch))

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
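# --- Reference: a minimal sketch of the Gaussian2DLikelihood loss used above.
# Hedged assumptions: the real implementation lives elsewhere in this repo and
# may differ in detail; here `outputs` is assumed to be (seq_length, numNodes, 5)
# holding raw bivariate-Gaussian parameters (mux, muy, sx, sy, corr), and the
# negative log-likelihood is averaged over the last pred_length frames, counting
# only the nodes present in each frame.
import numpy as np
import torch

def gaussian_2d_likelihood_sketch(outputs, targets, nodesPresent, pred_length):
    seq_length = outputs.size()[0]
    loss, counter = 0, 0
    for framenum in range(seq_length - pred_length, seq_length):
        for nodeID in nodesPresent[framenum]:
            mux, muy, sx, sy, corr = outputs[framenum, nodeID, :]
            sx, sy = torch.exp(sx), torch.exp(sy)   # enforce positive std-devs
            corr = torch.tanh(corr)                 # keep correlation in (-1, 1)
            x, y = targets[framenum, nodeID, 0], targets[framenum, nodeID, 1]
            zx, zy = (x - mux) / sx, (y - muy) / sy
            negRho = 1 - corr ** 2
            z = zx ** 2 + zy ** 2 - 2 * corr * zx * zy
            # Bivariate Gaussian density, guarded against log(0)
            density = torch.exp(-z / (2 * negRho)) / (2 * np.pi * sx * sy * torch.sqrt(negRho))
            loss = loss - torch.log(torch.clamp(density, min=1e-20))
            counter += 1
    return loss / counter if counter != 0 else loss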
def train(args):
    datasets = [i for i in range(5)]
    # Remove the leave-out dataset from the datasets
    datasets.remove(args.leaveDataset)

    # Construct the graph object and the data pointer
    filedir = '../beijing/'
    grp = My_Graph(args.seq_length + 1, filedir + 'W.pk', filedir + 'traffic_data.csv')
    my_datapt = My_DataPointer(grp.getFrameNum(), args.seq_length + 1, args.batch_size)

    # Log directory
    log_directory = '../log-beijing-16/'
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging files
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    save_directory = '../save-beijing-16/'
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Dump the training configuration
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_' + str(x) + '.tar')

    # Initialize net
    net = SRNN(args)
    net.cuda()
    if args.load >= 0:
        load_net(save_directory, args.load, net)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)

    learning_rate = args.learning_rate

    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    print(f"train_batches={my_datapt.num_batches()}")
    print(f"val_batches={my_datapt.valid_num_batches()}")

    # Training
    for epoch in range(args.num_epochs):
        log_output = open(log_directory + 'log_output.py', 'w')
        log_output.write("pairs=[\n")

        loss_epoch = 0
        my_datapt.train_reset()

        # For each batch
        for batch in range(my_datapt.num_batches()):
            start = time.time()

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for st in my_datapt.get_batch():
                # Construct the graph for the current sequence
                t1 = time.time()  # used by the commented timing probes below
                nodes, edges, nodesPresent, edgesPresent = grp.getSequence(st)

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()
                # print(f"grp time = {time.time()-t1}")

                # Forward prop
                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Log (prediction, ground truth) pairs for the last frame
                log_output.write(
                    f"[{outputs[-1,:,0].data.cpu().numpy().tolist()},{nodes[-1,:,0].data.cpu().numpy().tolist()}],\n")
                # print(f"forward time = {time.time()-t1}")

                # Compute loss
                loss = getL2Loss(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                print(f"start={st},loss={loss}")
                # print(f"loss time = {time.time()-t1}")
                loss_batch += loss.data

                # Compute gradients
                loss.backward()
                # print(f"backward time = {time.time()-t1}")

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()
                # print(f"step time = {time.time()-t1}")

            end = time.time()
            loss_batch = loss_batch / args.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(
                epoch * my_datapt.num_batches() + batch,
                args.num_epochs * my_datapt.num_batches(),
                epoch, loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= my_datapt.num_batches()

        # Log it
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')
        log_output.write("]\n")
        log_output.close()

        # Save the model after each epoch
        print('Saving model')
        torch.save({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path(epoch))

        # Validation
        my_datapt.val_reset()
        loss_epoch = 0

        with torch.no_grad():
            for batch in range(my_datapt.valid_num_batches()):
                # Loss for this batch
                loss_batch = 0

                for start in my_datapt.get_batch_val():
                    nodes, edges, nodesPresent, edgesPresent = grp.getSequence(start)

                    # Convert to cuda variables
                    nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                    edges = Variable(torch.from_numpy(edges).float()).cuda()

                    # Define hidden states
                    numNodes = nodes.size()[1]
                    hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                    hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                    cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                    cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                    outputs, _, _, _, _, _ = net(
                        nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs)

                    # Compute loss
                    loss = getL2Loss(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                    loss_batch += loss.data

                loss_batch = loss_batch / args.batch_size
                loss_epoch += loss_batch

        loss_epoch = loss_epoch / my_datapt.valid_num_batches()

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
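# --- Reference: a hedged sketch of the getL2Loss helper used by this variant.
# Assumption (not confirmed by the source): a plain sum-of-squares error between
# the first targets-width output channels and the ground truth over the last
# pred_length frames, averaged over the nodes present in each frame. The real
# helper may differ in masking or reduction details.
import torch

def get_l2_loss_sketch(outputs, targets, nodesPresent, pred_length):
    seq_length = outputs.size()[0]
    num_features = targets.size()[2]
    loss, counter = 0, 0
    for framenum in range(seq_length - pred_length, seq_length):
        for nodeID in nodesPresent[framenum]:
            diff = outputs[framenum, nodeID, :num_features] - targets[framenum, nodeID, :]
            loss = loss + torch.sum(diff ** 2)
            counter += 1
    return loss / counter if counter != 0 else loss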
def train(args):
    ## 19th Feb: moved training files for 3 and 4 to fine_obstacle under copy_2.
    # context_factor is an experiment for including potential destinations of
    # pedestrians in the graph; it did not yield a good improvement.
    # os.chdir('/home/serene/PycharmProjects/srnn-pytorch-master/')
    # os.chdir('/home/siri0005/srnn-pytorch-master')
    # os.chdir('/home/serene/Documents/KITTIData/GT')
    os.chdir('/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master')
    base_dir = '../fine_obstacle/prelu/p_02/'  # '../MultiNodeAttn_HH/'

    datasets = [0, 1, 2, 3, 4]
    # Remove the leave-out dataset from the datasets
    datasets.remove(args.leaveDataset)

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, datasets, forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    # log_directory = './log/world_data/normalized_01/'
    log_directory = base_dir + 'log/'
    log_directory += str(args.leaveDataset) + '/'
    log_directory += 'log_attention/'

    # Logging files
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    # save_directory = './save/world_data/normalized_01/'
    save_directory = base_dir + 'save/'
    save_directory += str(args.leaveDataset) + '/'
    save_directory += 'save_attention/'

    # Log the PReLU parameter
    param_log_dir = save_directory + 'param_log.txt'
    param_log = open(param_log_dir, 'w')

    # Dump the training configuration
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_' + str(x) + '.tar')

    # Initialize net
    net = SRNN(args)
    # srnn_model = SRNN(args)
    # net = torch.nn.DataParallel(srnn_model)
    # CUDA_VISIBLE_DEVICES = 1
    net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)
    # learning_rate = args.learning_rate

    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Resuming from a checkpoint (disabled):
    # start_epoch = 0
    # if args.leaveDataset == 3:
    #     ckp = torch.load(checkpoint_path(4))
    #     net.load_state_dict(ckp['state_dict'])
    #     optimizer.load_state_dict(ckp['optimizer_state_dict'])
    #     last_epoch = ckp['epoch']
    #     rang = range(last_epoch, args.num_epochs, 1)
    rang = range(args.num_epochs)

    # Training
    for epoch in rang:
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # if args.leaveDataset == 2 and epoch == 0:
        #     dataloader.num_batches = dataloader.num_batches + start_epoch
        #     epoch += start_epoch

        # Flag that controls transmission of previous hidden states into the
        # current hidden state vectors (part of the statefulness experiment)
        stateful = True

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data (shuffled input + stateless LSTM)
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]], d, args.distance_thresh)
                nodes, edges, nodesPresent, edgesPresent, obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()
                obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

                # Modification (31st Jan): reset hidden and cell states only once
                # after every batch, keeping states updated across sequences.
                # Define hidden states
                numNodes = nodes.size()[1]
                numObsNodes = obsNodes.size()[1]

                # if not stateful:
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                # Update (25th Jan): initializing the hidden state transition with
                # negative ones did not lead to new learning results; the network
                # converged quickly and the loss did not decrease below -6.
                hidden_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                hidden_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()
                cell_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                cell_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, h_node_rnn, h_edge_rnn, cell_node_rnn, cell_edge_rnn, o_h_node_rnn, o_h_edge_rnn, o_cell_node_rnn, o_cell_edge_rnn, _ = \
                    net(nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs,
                        obsNodes[:args.seq_length], obsEdges[:args.seq_length],
                        obsNodesPresent[:-1], obsEdgesPresent[:-1],
                        hidden_states_obs_node_RNNs, hidden_states_obs_edge_RNNs,
                        cell_states_obs_node_RNNs, cell_states_obs_edge_RNNs)

                # Stateful carry-over of hidden states between sequences (disabled):
                # else:
                #     if len(nodes) == len(hidden_states_node_RNNs):  # no additional nodes introduced in graph
                #         hidden_states_node_RNNs = Variable(h_node_rnn).cuda()
                #         cell_states_node_RNNs = Variable(cell_node_rnn).cuda()
                #     else:  # for now the number of nodes only increases over time as new pedestrians are detected
                #         pad_size = len(nodes) - len(hidden_states_node_RNNs)
                #         cell_states_node_RNNs = Variable(np.pad(cell_node_rnn, pad_size)).cuda()
                # if sequence > 0:
                #     hidden_states_node_RNNs = Variable(h_node_rnn).resize(hidden_states_node_RNNs.cpu().size())
                #     hidden_states_edge_RNNs = h_edge_rnn
                #     cell_states_node_RNNs = cell_node_rnn
                #     cell_states_edge_RNNs = cell_edge_rnn
                #     new_num_nodes = h_node_rnn.size()[0] - hidden_states_node_RNNs.size()[0]
                #     if h_node_rnn.size()[0] - hidden_states_node_RNNs.size()[0] >= 1:
                #         np.pad(hidden_states_node_RNNs.cpu(), new_num_nodes, mode='constant').cuda()
                #     hidden_states_obs_node_RNNs = o_h_node_rnn
                #     hidden_states_obs_edge_RNNs = o_h_edge_rnn
                #     cell_states_obs_node_RNNs = o_cell_node_rnn
                #     cell_states_obs_edge_RNNs = o_cell_edge_rnn

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Compute gradients
                loss.backward(retain_variables=True)
                param_log.write(str(net.alpha.data[0]) + '\n')

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(
                epoch * dataloader.num_batches + batch,
                args.num_epochs * dataloader.num_batches,
                epoch, loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches

        # Log it
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data (stateless LSTM, no shuffling)
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]], d, args.distance_thresh)
                nodes, edges, nodesPresent, edgesPresent, obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()
                obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                numObsNodes = obsNodes.size()[1]
                hidden_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                hidden_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()
                cell_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                cell_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                outputs, h_node_rnn, h_edge_rnn, cell_node_rnn, cell_edge_rnn, o_h_node_rnn, o_h_edge_rnn, o_cell_node_rnn, o_cell_edge_rnn, _ = \
                    net(nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs,
                        obsNodes[:args.seq_length], obsEdges[:args.seq_length],
                        obsNodesPresent[:-1], obsEdgesPresent[:-1],
                        hidden_states_obs_node_RNNs, hidden_states_obs_edge_RNNs,
                        cell_states_obs_node_RNNs, cell_states_obs_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')

        # Save the model after each epoch
        print('Saving model')
        torch.save({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path(epoch))

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
    param_log.close()
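# --- Illustrative refactor (not part of the original repo): every variant above
# repeats the same zero-initialization of hidden/cell state pairs four or eight
# times. A helper like this sketch would collapse each pair into one call.
import torch
from torch.autograd import Variable

def init_lstm_states_sketch(num_units, rnn_size, use_cuda=True):
    # Return a (hidden, cell) pair of zero tensors for num_units LSTM cells.
    hidden = Variable(torch.zeros(num_units, rnn_size))
    cell = Variable(torch.zeros(num_units, rnn_size))
    if use_cuda:
        hidden, cell = hidden.cuda(), cell.cuda()
    return hidden, cell

# Usage, replacing the explicit blocks above:
# hidden_states_node_RNNs, cell_states_node_RNNs = \
#     init_lstm_states_sketch(numNodes, args.human_node_rnn_size)
# hidden_states_edge_RNNs, cell_states_edge_RNNs = \
#     init_lstm_states_sketch(numNodes * numNodes, args.human_human_edge_rnn_size)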
def train(args): print("INPUT SEQUENCE LENGTH: {}".format(args.seq_length)) print("OUTPUT SEQUENCE LENGTH: {}".format(args.pred_length)) # Construct the DataLoader object dataloader = DataLoader(args.batch_size, args.seq_length + 1, forcePreProcess=False) # Construct the ST-graph object stgraph = ST_GRAPH(1, args.seq_length + 1) # Log directory log_directory = "../log/" # Logging file log_file_curve = open(os.path.join(log_directory, "log_curve.txt"), "w") log_file = open(os.path.join(log_directory, "val.txt"), "w") # Save directory save_directory = "../../save_weight/" # Open the configuration file # 현재 argument 세팅 저장. with open(os.path.join(save_directory, "config.pkl"), "wb") as f: pickle.dump(args, f) # Path to store the checkpoint file def checkpoint_path(x): return os.path.join(save_directory, "4Dgraph.S-{}.P-{}.srnn_model_"\ .format(args.seq_length, args.pred_length)\ + str(x) + ".tar") # Initialize net net = SRNN(args) if args.use_cuda: net = net.cuda() optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-5) # learning_rate = args.learning_rate logging.info("Training begin") best_val_loss = 100 best_epoch = 0 # Training for epoch in range(args.num_epochs): dataloader.reset_batch_pointer(valid=False) # Initialization loss_epoch = 0 # For each batch # dataloader.num_batches = 10. 1 epoch have 10 batches for batch in range(dataloader.num_batches): start = time.time() # Get batch data, mini-batch x, _, _, d = dataloader.next_batch(randomUpdate=True) # Loss for this batch loss_batch = 0 # For each sequence in the batch for sequence in range( dataloader.batch_size): #미니 배치에 있는 각 sequence 데이터에 대한 처리. # Construct the graph for the current sequence in {nodes, edges} stgraph.readGraph([x[sequence]]) nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence( ) #미니 배치에 있는 각 sequence의 graph 정보. ##### Convert to cuda variables ##### nodes = Variable(torch.from_numpy(nodes).float()) # nodes[0] represent all the person(object)'s corrdinate show up in frame 0. if args.use_cuda: nodes = nodes.cuda() edges = Variable(torch.from_numpy(edges).float()) if args.use_cuda: edges = edges.cuda() # Define hidden states numNodes = nodes.size()[1] #numNodes hidden_states_node_RNNs = Variable( torch.zeros(numNodes, args.node_rnn_size)) if args.use_cuda: hidden_states_node_RNNs = hidden_states_node_RNNs.cuda() hidden_states_edge_RNNs = Variable( torch.zeros(numNodes * numNodes, args.edge_rnn_size)) if args.use_cuda: hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda() cell_states_node_RNNs = Variable( torch.zeros(numNodes, args.node_rnn_size)) if args.use_cuda: cell_states_node_RNNs = cell_states_node_RNNs.cuda() cell_states_edge_RNNs = Variable( torch.zeros(numNodes * numNodes, args.edge_rnn_size)) if args.use_cuda: cell_states_edge_RNNs = cell_states_edge_RNNs.cuda() hidden_states_super_node_RNNs = Variable( #NOTE: 0 for peds., 1 for Bic., 2 for Veh. 
torch.zeros(3, args.node_rnn_size)) if args.use_cuda: hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda( ) cell_states_super_node_RNNs = Variable( torch.zeros(3, args.node_rnn_size)) if args.use_cuda: cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda( ) hidden_states_super_node_Edge_RNNs = Variable( torch.zeros(3, args.edge_rnn_size)) if args.use_cuda: hidden_states_super_node_Edge_RNNs = ( hidden_states_super_node_Edge_RNNs.cuda()) cell_states_super_node_Edge_RNNs = Variable( torch.zeros(3, args.edge_rnn_size)) if args.use_cuda: cell_states_super_node_Edge_RNNs = ( cell_states_super_node_Edge_RNNs.cuda()) # Zero out the gradients // Initialization Step net.zero_grad() optimizer.zero_grad() # Forward prop outputs, _, _, _, _, _, _, _, _, _ = net( nodes[:args.seq_length], edges[:args.seq_length], nodesPresent[:-1], edgesPresent[:-1], hidden_states_node_RNNs, hidden_states_edge_RNNs, cell_states_node_RNNs, cell_states_edge_RNNs, hidden_states_super_node_RNNs, hidden_states_super_node_Edge_RNNs, cell_states_super_node_RNNs, cell_states_super_node_Edge_RNNs, ) # Compute loss loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length) loss_batch += loss.item() # embed() # Compute gradients loss.backward() # Clip gradients torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip) # Update parameters optimizer.step() # Reset the stgraph stgraph.reset() end = time.time() loss_batch = loss_batch / dataloader.batch_size ##### NOTE: Expected Loss; E[L] loss_epoch += loss_batch logging.info( "{}/{} (epoch {}), train_loss = {:.12f}, time/batch = {:.12f}". format( epoch * dataloader.num_batches + batch, args.num_epochs * dataloader.num_batches, epoch, loss_batch, end - start, )) # Compute loss for the entire epoch loss_epoch /= dataloader.num_batches # Log it log_file_curve.write(str(epoch) + "," + str(loss_epoch) + ",") ##################### Validation Part ##################### dataloader.reset_batch_pointer(valid=True) loss_epoch = 0 for batch in range(dataloader.valid_num_batches): # Get batch data x, _, d = dataloader.next_valid_batch(randomUpdate=False) # Loss for this batch loss_batch = 0 for sequence in range(dataloader.batch_size): stgraph.readGraph([x[sequence]]) nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence( ) # Convert to cuda variables nodes = Variable(torch.from_numpy(nodes).float()) if args.use_cuda: nodes = nodes.cuda() edges = Variable(torch.from_numpy(edges).float()) if args.use_cuda: edges = edges.cuda() # Define hidden states numNodes = nodes.size()[1] hidden_states_node_RNNs = Variable( torch.zeros(numNodes, args.node_rnn_size)) if args.use_cuda: hidden_states_node_RNNs = hidden_states_node_RNNs.cuda() hidden_states_edge_RNNs = Variable( torch.zeros(numNodes * numNodes, args.edge_rnn_size)) if args.use_cuda: hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda() cell_states_node_RNNs = Variable( torch.zeros(numNodes, args.node_rnn_size)) if args.use_cuda: cell_states_node_RNNs = cell_states_node_RNNs.cuda() cell_states_edge_RNNs = Variable( torch.zeros(numNodes * numNodes, args.edge_rnn_size)) if args.use_cuda: cell_states_edge_RNNs = cell_states_edge_RNNs.cuda() hidden_states_super_node_RNNs = Variable( torch.zeros(3, args.node_rnn_size)) if args.use_cuda: hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda( ) cell_states_super_node_RNNs = Variable( torch.zeros(3, args.node_rnn_size)) if args.use_cuda: cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda( ) 
hidden_states_super_node_Edge_RNNs = Variable( torch.zeros(3, args.edge_rnn_size)) if args.use_cuda: hidden_states_super_node_Edge_RNNs = ( hidden_states_super_node_Edge_RNNs.cuda()) cell_states_super_node_Edge_RNNs = Variable( torch.zeros(3, args.edge_rnn_size)) if args.use_cuda: cell_states_super_node_Edge_RNNs = ( cell_states_super_node_Edge_RNNs.cuda()) outputs, _, _, _, _, _, _, _, _, _ = net( nodes[:args.seq_length], edges[:args.seq_length], nodesPresent[:-1], edgesPresent[:-1], hidden_states_node_RNNs, hidden_states_edge_RNNs, cell_states_node_RNNs, cell_states_edge_RNNs, hidden_states_super_node_RNNs, hidden_states_super_node_Edge_RNNs, cell_states_super_node_RNNs, cell_states_super_node_Edge_RNNs, ) # Compute loss loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length) loss_batch += loss.item() # Reset the stgraph stgraph.reset() loss_batch = loss_batch / dataloader.batch_size #### NOTE: Expected Loss; E[L] loss_epoch += loss_batch loss_epoch = loss_epoch / dataloader.valid_num_batches # Update best validation loss until now if loss_epoch < best_val_loss: best_val_loss = loss_epoch best_epoch = epoch # Record best epoch and best validation loss logging.info("(epoch {}), valid_loss = {:.3f}".format( epoch, loss_epoch)) logging.info("Best epoch {}, Best validation loss {}".format( best_epoch, best_val_loss)) # Log it log_file_curve.write(str(loss_epoch) + "\n") # Save the model after each epoch logging.info("Saving model") torch.save( { "epoch": epoch, "state_dict": net.state_dict(), "optimizer_state_dict": optimizer.state_dict(), }, checkpoint_path(epoch), ) # Record the best epoch and best validation loss overall logging.info("Best epoch {}, Best validation loss {}".format( best_epoch, best_val_loss)) # Log it log_file.write(str(best_epoch) + "," + str(best_val_loss)) # Close logging files log_file.close() log_file_curve.close()
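# --- Reference: resuming from one of the checkpoints saved above. The dict keys
# ('epoch', 'state_dict', 'optimizer_state_dict') match the torch.save call in
# the training loop; the variants only hint at resuming in commented-out lines,
# so treat this as a sketch rather than the repo's own resume logic.
import torch

def load_checkpoint_sketch(path, net, optimizer):
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

# Usage: start_epoch = load_checkpoint_sketch(checkpoint_path(4), net, optimizer) + 1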
def train(args):
    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, forcePreProcess=False)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = os.path.join(os.getcwd(), "log")
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging files
    log_file_curve = open(os.path.join(log_directory, "log_curve.txt"), "w+")
    log_file = open(os.path.join(log_directory, "val.txt"), "w")

    # Save directory
    save_directory = os.path.join(os.getcwd(), "save")
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Dump the training configuration
    with open(os.path.join(save_directory, "config.pkl"), "wb") as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, "srnn_model_" + str(x) + ".tar")

    # Initialize net and warm-start from a reference checkpoint
    net = SRNN(args)
    net.load_state_dict(torch.load('modelref/srnn_model_271.tar'), strict=False)
    if args.use_cuda:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-5)
    # learning_rate = args.learning_rate

    logging.info("Training begin")
    best_val_loss = 100
    best_epoch = 0

    global plotter
    plotter = VisdomLinePlotter(env_name='main')

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch (e.g. with dataloader.num_batches = 10, one epoch has 10 batches)
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables;
                # nodes[0] holds the coordinates of every person present in frame 0.
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda()
                cell_states_super_node_RNNs = Variable(torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda()
                hidden_states_super_node_Edge_RNNs = Variable(torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = hidden_states_super_node_Edge_RNNs.cuda()
                cell_states_super_node_Edge_RNNs = Variable(torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = cell_states_super_node_Edge_RNNs.cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs, hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs, cell_states_super_node_Edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            logging.info(
                "{}/{} (epoch {}), train_loss = {:.12f}, time/batch = {:.12f}".format(
                    epoch * dataloader.num_batches + batch,
                    args.num_epochs * dataloader.num_batches,
                    epoch, loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        plotter.plot('loss', 'train', 'Class Loss', epoch, loss_epoch)

        # Log it
        log_file_curve.write(str(epoch) + "," + str(loss_epoch) + ",")

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda()
                cell_states_super_node_RNNs = Variable(torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda()
                hidden_states_super_node_Edge_RNNs = Variable(torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = hidden_states_super_node_Edge_RNNs.cuda()
                cell_states_super_node_Edge_RNNs = Variable(torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = cell_states_super_node_Edge_RNNs.cuda()

                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs, hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs, cell_states_super_node_Edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        plotter.plot('loss', 'val', 'Class Loss', epoch, loss_epoch)

        # Record best epoch and best validation loss
        logging.info("(epoch {}), valid_loss = {:.3f}".format(epoch, loss_epoch))
        logging.info("Best epoch {}, Best validation loss {}".format(best_epoch, best_val_loss))

        # Log it
        log_file_curve.write(str(loss_epoch) + "\n")

        # Save the model after each epoch
        logging.info("Saving model")
        torch.save({
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, checkpoint_path(epoch))

    # Record the best epoch and best validation loss overall
    logging.info("Best epoch {}, Best validation loss {}".format(best_epoch, best_val_loss))

    # Log it
    log_file.write(str(best_epoch) + "," + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
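# --- The variant above assumes a VisdomLinePlotter utility that is not defined
# in this file. A minimal sketch compatible with the
# plotter.plot('loss', 'train', 'Class Loss', epoch, loss_epoch) calls above
# might look like this (hedged: the repo's actual class may differ).
import numpy as np
from visdom import Visdom

class VisdomLinePlotterSketch(object):
    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}  # one visdom window per variable name

    def plot(self, var_name, split_name, title_name, x, y):
        if var_name not in self.plots:
            # First point for this variable: create the window
            self.plots[var_name] = self.viz.line(
                X=np.array([x]), Y=np.array([y]), env=self.env,
                opts=dict(legend=[split_name], title=title_name,
                          xlabel='Epochs', ylabel=var_name))
        else:
            # Append to the existing window, one trace per split
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env,
                          win=self.plots[var_name], name=split_name,
                          update='append')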
def train(args):
    # Construct the DataLoader object
    # args: (batch_size=50, seq_length=5, datasets=[0, 1, 2, 3, 4, 5, 6], forcePreProcess=False, infer=False)
    # seq_length + 1 so that nodes[1:] can serve as the one-step-shifted target below
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, args.train_dataset, forcePreProcess=True)

    # Construct the ST-graph object; it is fed one sequence at a time
    # (readGraph([x[sequence]]) below), hence the batch size of 1
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = 'log/trainedOn_' + str(args.train_dataset)
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging files
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory for saving the model
    save_directory = 'save/trainedOn_' + str(args.train_dataset)
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Store the parser arguments in the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file, i.e. the model after a particular epoch
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_' + str(x) + '.tar')

    # Initialize net
    net = SRNN(args)
    if args.use_cuda:
        net = net.cuda()

    # optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)
    optimizer = torch.optim.Adagrad(net.parameters())

    learning_rate = args.learning_rate

    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data. Format:
            #   x_batch: input sequences of length seq_length
            #   y_batch: output sequences of the same length, shifted by one step in time
            #   frame_batch: frame IDs in the batch
            #   d: current position of the dataset pointer (points to the next batch to be loaded)
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(
                epoch * dataloader.num_batches + batch,
                args.num_epochs * dataloader.num_batches,
                epoch, loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches

        # Log it
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                outputs, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Saving the model (every epoch if save_every is set, otherwise only on improvement)
        if loss_epoch < best_val_loss or args.save_every:
            print('Saving model')
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))

    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
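# --- None of the variants above ships its driver in this file; a minimal entry
# point for the last variant might look like the sketch below. The argument
# names are the ones referenced in that function; the defaults are illustrative,
# and SRNN(args) may read further fields not listed here.
import argparse

def main_sketch():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--seq_length', type=int, default=8)
    parser.add_argument('--pred_length', type=int, default=8)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--grad_clip', type=float, default=10.0)
    parser.add_argument('--train_dataset', type=int, default=0)
    parser.add_argument('--human_node_rnn_size', type=int, default=64)
    parser.add_argument('--human_human_edge_rnn_size', type=int, default=128)
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--save_every', action='store_true')
    args = parser.parse_args()
    train(args)

if __name__ == '__main__':
    main_sketch()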