import argparse
import os
import pickle
import time

import numpy as np
import torch
from torch.autograd import Variable

# NOTE: the project imports below assume the usual Social LSTM repository layout
# (utils.py, st_graph.py, model.py, grid.py, criterion.py); adjust the module paths
# if your checkout differs. get_seq_mask, sample, get_mean_error, get_final_error
# and H_path are assumed to be provided by the project's helper modules.
from utils import DataLoader
from st_graph import ST_GRAPH
from model import SocialLSTM
from grid import getSequenceGridMask
from criterion import Gaussian2DLikelihood


def train(args):
    """Train the Social LSTM model, leaving out the dataset given by args.leaveDataset."""
    datasets = [i for i in range(5)]
    # Remove the leave-out dataset from the datasets
    datasets.remove(args.leaveDataset)

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, datasets, forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(args.batch_size, args.seq_length + 1)

    # Log directory
    log_directory = 'log/'
    log_directory += str(args.leaveDataset) + '/'
    # Make sure the log directory exists before opening the log files
    os.makedirs(log_directory, exist_ok=True)

    # Logging files
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    save_directory = 'save/'
    save_directory += str(args.leaveDataset) + '/'
    os.makedirs(save_directory, exist_ok=True)

    # Dump the arguments into the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'social_lstm_model_' + str(x) + '.tar')

    # Initialize net
    net = SocialLSTM(args)
    net.cuda()

    optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate)
    learning_rate = args.learning_rate

    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, d = dataloader.next_batch()

            # Construct the stgraph
            stgraph.readGraph(x)

            loss_batch = 0

            # For each sequence
            for sequence in range(dataloader.batch_size):
                # Get the data corresponding to the current sequence
                x_seq, d_seq = x[sequence], d[sequence]

                # Dataset dimensions
                if d_seq == 0 and datasets[0] == 0:
                    dataset_data = [640, 480]
                else:
                    dataset_data = [720, 576]

                # Compute grid masks
                grid_seq = getSequenceGridMask(x_seq, dataset_data, args.neighborhood_size, args.grid_size)
                obst_seq = get_seq_mask(x_seq, d_seq, dataset_data, args.neighborhood_size, args.grid_size)

                # Get the node features and nodes present from stgraph
                nodes, _, nodesPresent, _ = stgraph.getSequence(sequence)

                # Construct variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                # nodes = Variable(torch.from_numpy(nodes).float())
                numNodes = nodes.size()[1]

                hidden_states = Variable(torch.zeros(numNodes, args.rnn_size)).cuda()
                cell_states = Variable(torch.zeros(numNodes, args.rnn_size)).cuda()
                # hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                # cell_states = Variable(torch.zeros(numNodes, args.rnn_size))

                # Zero out gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _ = net(nodes[:-1], grid_seq[:-1], obst_seq[:-1], nodesPresent[:-1],
                                    hidden_states, cell_states)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

            # Reset stgraph
            stgraph.reset()
            end = time.time()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(
                epoch * dataloader.num_batches + batch,
                args.num_epochs * dataloader.num_batches,
                epoch, loss_batch, end - start))

        loss_epoch /= dataloader.num_batches

        # Log loss values
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Read the st graph from data
            stgraph.readGraph(x)

            # Loss for this batch
            loss_batch = 0

            # For each sequence
            for sequence in range(dataloader.batch_size):
                # Get data corresponding to the current sequence
                x_seq, d_seq = x[sequence], d[sequence]

                # Dataset dimensions
                if d_seq == 0 and datasets[0] == 0:
                    dataset_data = [640, 480]
                else:
                    dataset_data = [720, 576]

                # Compute grid masks
                grid_seq = getSequenceGridMask(x_seq, dataset_data, args.neighborhood_size, args.grid_size)
                obst_seq = get_seq_mask(x_seq, d_seq, dataset_data, args.neighborhood_size, args.grid_size)

                # Get node features and nodes present from stgraph
                nodes, _, nodesPresent, _ = stgraph.getSequence(sequence)

                # Construct variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                # nodes = Variable(torch.from_numpy(nodes).float())
                numNodes = nodes.size()[1]

                # hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                # cell_states = Variable(torch.zeros(numNodes, args.rnn_size))
                hidden_states = Variable(torch.zeros(numNodes, args.rnn_size)).cuda()
                cell_states = Variable(torch.zeros(numNodes, args.rnn_size)).cuda()

                # Forward prop
                outputs, _, _ = net(nodes[:-1], grid_seq[:-1], obst_seq[:-1], nodesPresent[:-1],
                                    hidden_states, cell_states)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

            # Reset the stgraph
            stgraph.reset()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch', best_epoch, 'Best validation loss', best_val_loss)
        log_file_curve.write(str(loss_epoch) + '\n')

        # Save the model after each epoch
        print('Saving model')
        torch.save({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path(epoch))

    print('Best epoch', best_epoch, 'Best validation Loss', best_val_loss)
    # Log the best epoch and best validation loss
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
def main():
    """Load a trained Social LSTM checkpoint and report displacement errors on the test dataset."""
    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length', type=int, default=8,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length', type=int, default=12,
                        help='Predicted length of the trajectory')
    # Test dataset
    parser.add_argument('--test_dataset', type=int, default=3,
                        help='Dataset to be tested on')
    # Model to be loaded
    parser.add_argument('--epoch', type=int, default=107,
                        help='Epoch of model to be loaded')

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory
    save_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/FixedPixel_Normalized_150epoch/' + str(sample_args.test_dataset) + '/'
    # NOTE: the hardcoded path below overrides the parameterized one above (kept from the original script)
    save_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/FixedPixel_Normalized_150epoch/1/'
    output_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/'

    # Define the path for the config file for saved args
    with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SocialLSTM(saved_args, True)
    net.cuda()

    # Get the checkpoint path
    checkpoint_path = os.path.join(save_directory, 'social_lstm_model_' + str(sample_args.epoch) + '.tar')
    # checkpoint_path = os.path.join(save_directory, 'srnn_model.tar')

    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at epoch', model_epoch)

    # Homography matrix for the test dataset
    H = np.loadtxt(H_path[sample_args.test_dataset])

    # Test dataset
    dataset = [sample_args.test_dataset]

    # Create the DataLoader object
    dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length, dataset, True, infer=True)
    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    results = []

    # Variables to maintain total mean and final displacement error
    total_error = 0
    final_error = 0

    # For each batch
    for batch in range(dataloader.num_batches):
        start = time.time()

        # Get data
        x, _, d = dataloader.next_batch(randomUpdate=False)

        # Get the sequence
        x_seq, d_seq = x[0], d[0]

        # Dimensions of the dataset
        if d_seq == 0 and dataset[0] == 0:
            dimensions = [640, 480]
        else:
            dimensions = [720, 576]
        # NOTE: this fixed value overrides the dimensions computed above (kept from the original script)
        dimensions = [1224, 370]

        # Get the grid masks for the sequence
        grid_seq = getSequenceGridMask(x_seq, dimensions, saved_args.neighborhood_size, saved_args.grid_size)

        # Construct ST graph
        stgraph.readGraph(x)

        # Get nodes and nodesPresent
        nodes, _, nodesPresent, _ = stgraph.getSequence(0)
        nodes = Variable(torch.from_numpy(nodes).float(), volatile=True).cuda()

        # Extract the observed part of the trajectories
        obs_nodes, obs_nodesPresent, obs_grid = (nodes[:sample_args.obs_length],
                                                 nodesPresent[:sample_args.obs_length],
                                                 grid_seq[:sample_args.obs_length])

        # The sample function
        ret_nodes = sample(obs_nodes, obs_nodesPresent, obs_grid, sample_args, net,
                           nodes, nodesPresent, grid_seq, saved_args, dimensions)

        # print(nodes[sample_args.obs_length:].data)

        # Record the mean and final displacement error
        total_error += get_mean_error(ret_nodes[sample_args.obs_length:].data,
                                      nodes[sample_args.obs_length:].data,
                                      nodesPresent[sample_args.obs_length - 1],
                                      nodesPresent[sample_args.obs_length:],
                                      H, sample_args.test_dataset)
        final_error += get_final_error(ret_nodes[sample_args.obs_length:].data,
                                       nodes[sample_args.obs_length:].data,
                                       nodesPresent[sample_args.obs_length - 1],
                                       nodesPresent[sample_args.obs_length:],
                                       H, sample_args.test_dataset)

        end = time.time()

        print('Processed trajectory number : ', batch, 'out of', dataloader.num_batches,
              'trajectories in time', end - start)

        results.append((nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(),
                        nodesPresent, sample_args.obs_length))

        # Reset the ST graph
        stgraph.reset()

    print('Total mean error of the model is ', total_error / dataloader.num_batches)
    print('Total final error of the model is ', final_error / dataloader.num_batches)

    print('Saving results')
    with open(os.path.join(output_directory, 'results.pkl'), 'wb') as f:
        pickle.dump(results, f)
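

# Entry point: a minimal guard, assuming main() (the evaluation routine above) is the
# intended script entry point; train(args) would be wired to its own argument parser
# elsewhere in the project.
if __name__ == '__main__':
    main()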