def train(args):
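    """Train the Social LSTM model with leave-one-out cross-validation.

    The dataset indexed by args.leaveDataset is held out from the five
    available datasets; the model is trained on the remaining four,
    validated at the end of every epoch, and checkpointed after each one.
    """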
    datasets = list(range(5))
    # Remove the held-out dataset from the training set
    datasets.remove(args.leaveDataset)

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            datasets,
                            forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(args.batch_size, args.seq_length + 1)

    # Log directory (created if it does not already exist)
    log_directory = 'log/' + str(args.leaveDataset) + '/'
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging files
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory (created if it does not already exist)
    save_directory = 'save/' + str(args.leaveDataset) + '/'
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Dump the arguments into the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory,
                            'social_lstm_model_' + str(x) + '.tar')

    # Initialize net
    net = SocialLSTM(args)
    net.cuda()

    optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate)
    learning_rate = args.learning_rate

    print('Training begins')
    best_val_loss = float('inf')
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, d = dataloader.next_batch()

            # Construct the stgraph
            stgraph.readGraph(x)

            loss_batch = 0

            # For each sequence
            for sequence in range(dataloader.batch_size):
                # Get the data corresponding to the current sequence
                x_seq, d_seq = x[sequence], d[sequence]

                # Frame dimensions of the source video: dataset 0 uses
                # 640x480, all others use 720x576
                if d_seq == 0 and datasets[0] == 0:
                    dataset_data = [640, 480]
                else:
                    dataset_data = [720, 576]

                # Compute grid masks
                grid_seq = getSequenceGridMask(x_seq, dataset_data,
                                               args.neighborhood_size,
                                               args.grid_size)
                obst_seq = get_seq_mask(x_seq, d_seq, dataset_data,
                                        args.neighborhood_size, args.grid_size)

                # Get the node features and nodes present from stgraph
                nodes, _, nodesPresent, _ = stgraph.getSequence(sequence)

                # Construct variables (the commented lines are CPU fallbacks
                # for running without CUDA)
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                # nodes = Variable(torch.from_numpy(nodes).float())
                numNodes = nodes.size()[1]
                hidden_states = Variable(torch.zeros(numNodes,
                                                     args.rnn_size)).cuda()
                cell_states = Variable(torch.zeros(numNodes,
                                                   args.rnn_size)).cuda()
                # hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                # cell_states = Variable(torch.zeros(numNodes, args.rnn_size))

                # Zero out gradients (optimizer.zero_grad() already covers all
                # of the network's parameters, so net.zero_grad() is redundant)
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _ = net(nodes[:-1], grid_seq[:-1], obst_seq[:-1],
                                    nodesPresent[:-1], hidden_states,
                                    cell_states)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

            # Reset stgraph
            stgraph.reset()
            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.
                  format(epoch * dataloader.num_batches + batch,
                         args.num_epochs * dataloader.num_batches, epoch,
                         loss_batch, end - start))

        loss_epoch /= dataloader.num_batches
        # Log loss values
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Read the st graph from data
            stgraph.readGraph(x)

            # Loss for this batch
            loss_batch = 0

            # For each sequence
            for sequence in range(dataloader.batch_size):
                # Get data corresponding to the current sequence
                x_seq, d_seq = x[sequence], d[sequence]

                # Frame dimensions of the source video: dataset 0 uses
                # 640x480, all others use 720x576
                if d_seq == 0 and datasets[0] == 0:
                    dataset_data = [640, 480]
                else:
                    dataset_data = [720, 576]

                # Compute grid masks
                grid_seq = getSequenceGridMask(x_seq, dataset_data,
                                               args.neighborhood_size,
                                               args.grid_size)
                obst_seq = get_seq_mask(x_seq, d_seq, dataset_data,
                                        args.neighborhood_size, args.grid_size)
                # Get node features and nodes present from stgraph
                nodes, _, nodesPresent, _ = stgraph.getSequence(sequence)

                # Construct variables; volatile=True avoids building the
                # autograd graph, since no backward pass is needed during
                # validation (the commented lines are CPU fallbacks)
                nodes = Variable(torch.from_numpy(nodes).float(),
                                 volatile=True).cuda()
                # nodes = Variable(torch.from_numpy(nodes).float(), volatile=True)
                numNodes = nodes.size()[1]
                # hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                # cell_states = Variable(torch.zeros(numNodes, args.rnn_size))
                hidden_states = Variable(torch.zeros(numNodes,
                                                     args.rnn_size)).cuda()
                cell_states = Variable(torch.zeros(numNodes,
                                                   args.rnn_size)).cuda()

                # Forward prop
                outputs, _, _ = net(nodes[:-1], grid_seq[:-1], obst_seq[:-1],
                                    nodesPresent[:-1], hidden_states,
                                    cell_states)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

            # Reset the stgraph
            stgraph.reset()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch', best_epoch, 'Best validation loss', best_val_loss)
        log_file_curve.write(str(loss_epoch) + '\n')

        # Save the model after each epoch
        print('Saving model')
        torch.save(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

    print('Best epoch', best_epoch, 'Best validation loss', best_val_loss)
    # Log the best epoch and best validation loss
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()


def main():
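    """Load a trained Social LSTM checkpoint and evaluate it on the test dataset."""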

    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length', type=int, default=8,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length', type=int, default=12,
                        help='Predicted length of the trajectory')
    # Test dataset
    parser.add_argument('--test_dataset', type=int, default=3,
                        help='Dataset to be tested on')

    # Model to be loaded
    parser.add_argument('--epoch', type=int, default=107,
                        help='Epoch of model to be loaded')

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory for the chosen test dataset
    save_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/FixedPixel_Normalized_150epoch/' + str(sample_args.test_dataset) + '/'
    output_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/'


    # Load the arguments the model was trained with
    with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SocialLSTM(saved_args, True)
    net.cuda()

    # Get the checkpoint path
    checkpoint_path = os.path.join(
        save_directory, 'social_lstm_model_' + str(sample_args.epoch) + '.tar')
    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at epoch', model_epoch)
    else:
        raise IOError('No checkpoint found at ' + checkpoint_path)

    # Homography matrix for the test dataset (H_path is assumed to be defined
    # at module level)
    H = np.loadtxt(H_path[sample_args.test_dataset])

    # Test dataset
    dataset = [sample_args.test_dataset]

    # Create the DataLoader object
    dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length,
                            dataset, True, infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    results = []

    # Variable to maintain total error
    total_error = 0
    final_error = 0
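    # total_error accumulates the average displacement error (ADE) over the
    # predicted steps; final_error accumulates the displacement at the final
    # predicted step (FDE)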

    # For each batch
    for batch in range(dataloader.num_batches):
        start = time.time()

        # Get data
        x, _, d = dataloader.next_batch(randomUpdate=False)

        # Get the sequence
        x_seq, d_seq = x[0], d[0]

        # Dimensions of the dataset
        if d_seq == 0 and dataset[0] == 0:
            dimensions = [640, 480]
        else:
            dimensions = [720, 576]

        # Hard-coded override of the frame dimensions; this supersedes the
        # per-dataset values above
        dimensions = [1224, 370]

        # Get the grid masks for the sequence
        grid_seq = getSequenceGridMask(x_seq, dimensions, saved_args.neighborhood_size, saved_args.grid_size)

        # Construct ST graph
        stgraph.readGraph(x)

        # Get nodes and nodesPresent
        nodes, _, nodesPresent, _ = stgraph.getSequence(0)
        nodes = Variable(torch.from_numpy(nodes).float(), volatile=True).cuda()

        # Extract the observed part of the trajectories
        obs_nodes = nodes[:sample_args.obs_length]
        obs_nodesPresent = nodesPresent[:sample_args.obs_length]
        obs_grid = grid_seq[:sample_args.obs_length]

        # Sample the predicted trajectory conditioned on the observed part
        ret_nodes = sample(obs_nodes, obs_nodesPresent, obs_grid, sample_args,
                           net, nodes, nodesPresent, grid_seq, saved_args,
                           dimensions)
        # Record the mean and final displacement error
        total_error += get_mean_error(ret_nodes[sample_args.obs_length:].data,
                                      nodes[sample_args.obs_length:].data,
                                      nodesPresent[sample_args.obs_length - 1],
                                      nodesPresent[sample_args.obs_length:],
                                      H, sample_args.test_dataset)
        final_error += get_final_error(ret_nodes[sample_args.obs_length:].data,
                                       nodes[sample_args.obs_length:].data,
                                       nodesPresent[sample_args.obs_length - 1],
                                       nodesPresent[sample_args.obs_length:],
                                       H, sample_args.test_dataset)

        end = time.time()

        print('Processed trajectory number:', batch, 'out of',
              dataloader.num_batches, 'trajectories in time', end - start)

        results.append((nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(),
                        nodesPresent, sample_args.obs_length))

        # Reset the ST graph
        stgraph.reset()

    print('Total mean error of the model is ', total_error / dataloader.num_batches)
    print('Total final error of the model is ', final_error / dataloader.num_batches)

    print('Saving results')
    with open(os.path.join(output_directory, 'results.pkl'), 'wb') as f:
        pickle.dump(results, f)
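

# Entry point when the script is run directly (assumed; adjust if this module
# is imported elsewhere)
if __name__ == '__main__':
    main()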