Example #1
def main():

    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length',
                        type=int,
                        default=8,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length',
                        type=int,
                        default=12,
                        help='Predicted length of the trajectory')

    # Train Dataset
    # Use like:
    # python transpose_inrange.py --train_dataset index_1 index_2 ...
    parser.add_argument(
        '-l',
        '--train_dataset',
        nargs='+',
        type=int,
        default=[0, 1, 2, 4],
        help='training dataset(s) the model is trained on: --train_dataset index_1 index_2 ...')

    # Test dataset
    parser.add_argument('--test_dataset',
                        type=int,
                        default=3,
                        help='Dataset to be tested on')

    # Model to be loaded
    parser.add_argument('--epoch',
                        type=int,
                        default=26,
                        help='Epoch of model to be loaded')

    # Use GPU or not
    parser.add_argument('--use_cuda',
                        action="store_true",
                        default=False,
                        help="Use GPU or CPU")

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory
    load_directory = 'save/'
    load_directory += 'trainedOn_' + str(sample_args.train_dataset)

    # Define the path for the config file for saved args
    ## Arguments of the parser used during training
    with open(os.path.join(load_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SRNN(saved_args, True)
    if saved_args.use_cuda:
        net = net.cuda()

    checkpoint_path = os.path.join(
        load_directory, 'srnn_model_' + str(sample_args.epoch) + '.tar')

    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at {}'.format(model_epoch))

    # Dataset to get data from
    dataset = [sample_args.test_dataset]

    dataloader = DataLoader(1,
                            sample_args.pred_length + sample_args.obs_length,
                            dataset,
                            True,
                            infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    results = []

    # Variable to maintain total error
    total_error = 0
    final_error = 0

    for batch in range(dataloader.num_batches):
        start = time.time()

        # Get the next batch
        x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

        # Construct ST graph
        stgraph.readGraph(x)

        nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

        # Wrap in (pre-0.4 PyTorch) Variables; volatile=True disables gradient
        # tracking for inference. Moved to the GPU below if use_cuda is set.
        nodes = Variable(torch.from_numpy(nodes).float(), volatile=True)
        edges = Variable(torch.from_numpy(edges).float(), volatile=True)
        if saved_args.use_cuda:
            nodes = nodes.cuda()
            edges = edges.cuda()

        # Separate out the observed part of the trajectory
        obs_nodes = nodes[:sample_args.obs_length]
        obs_edges = edges[:sample_args.obs_length]
        obs_nodesPresent = nodesPresent[:sample_args.obs_length]
        obs_edgesPresent = edgesPresent[:sample_args.obs_length]

        # Sample function
        ret_nodes, ret_attn, ret_new_attn = sample(obs_nodes, obs_edges,
                                                   obs_nodesPresent,
                                                   obs_edgesPresent,
                                                   sample_args, net, nodes,
                                                   edges, nodesPresent)

        # Compute mean and final displacement error
        total_error += get_mean_error(ret_nodes[sample_args.obs_length:].data,
                                      nodes[sample_args.obs_length:].data,
                                      nodesPresent[sample_args.obs_length - 1],
                                      nodesPresent[sample_args.obs_length:],
                                      saved_args.use_cuda)
        final_error += get_final_error(
            ret_nodes[sample_args.obs_length:].data,
            nodes[sample_args.obs_length:].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length:])

        end = time.time()

        print('Processed trajectory number : ', batch, 'out of',
              dataloader.num_batches, 'trajectories in time', end - start)

        # Store results
        if saved_args.use_cuda:
            results.append(
                (nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(),
                 nodesPresent, sample_args.obs_length, ret_attn, ret_new_attn,
                 frameIDs))
        else:
            results.append(
                (nodes.data.numpy(), ret_nodes.data.numpy(), nodesPresent,
                 sample_args.obs_length, ret_attn, ret_new_attn, frameIDs))

        # Reset the ST graph
        stgraph.reset()

    print('Total mean error of the model is ',
          total_error / dataloader.num_batches)
    print('Total final error of the model is ',
          final_error / dataloader.num_batches)

    print('Saving results')
    save_directory = load_directory + '/testedOn_' + str(
        sample_args.test_dataset)
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
    with open(os.path.join(save_directory, 'results.pkl'), 'wb') as f:
        pickle.dump(results, f)
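
For reference, a minimal sketch of how the results.pkl written above could be read back. The tuple layout follows the results.append(...) call in this example; the load path is an assumption based on the save-directory scheme used here.

import pickle

# Hypothetical path, following the 'save/trainedOn_.../testedOn_...' layout above
with open('save/trainedOn_[0, 1, 2, 4]/testedOn_3/results.pkl', 'rb') as f:
    results = pickle.load(f)

# Each entry mirrors the tuple stored by results.append(...) above
for nodes, ret_nodes, nodesPresent, obs_length, ret_attn, ret_new_attn, frameIDs in results:
    print('frames:', frameIDs, 'observed steps:', obs_length,
          'full trajectory shape:', nodes.shape, 'sampled shape:', ret_nodes.shape)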
Example #2
def main():
    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument("--obs_length",
                        type=int,
                        default=4,
                        help="Observed length of the trajectory")
    # Predicted length of the trajectory parameter
    parser.add_argument("--pred_length",
                        type=int,
                        default=6,
                        help="Predicted length of the trajectory")
    # Model to be loaded
    parser.add_argument("--epoch",
                        type=int,
                        default=233,
                        help="Epoch of model to be loaded")

    # Use GPU or not. Note: with action="store_true" and default=True,
    # passing --use_cuda has no effect; CUDA is effectively always enabled.
    parser.add_argument("--use_cuda",
                        action="store_true",
                        default=True,
                        help="Use GPU or CPU")

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory (os.path.join keeps this portable across platforms)
    save_directory = os.path.join(os.getcwd(), "save")
    # Define the path for the config file for saved args
    with open(os.path.join(save_directory, "config.pkl"), "rb") as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SRNN(saved_args, True)
    if saved_args.use_cuda:
        net = net.cuda()

    checkpoint_path = os.path.join(
        save_directory, "srnn_model_" + str(sample_args.epoch) + ".tar")

    if os.path.isfile(checkpoint_path):
        print("Loading checkpoint")
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint["epoch"]
        net.load_state_dict(checkpoint["state_dict"])
        print("Loaded checkpoint at {}".format(model_epoch))

    dataloader = DataLoader(1,
                            sample_args.pred_length + sample_args.obs_length,
                            infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    results = []

    # Variable to maintain total error
    # total_error = 0
    # final_error = 0
    avg_ped_error = 0
    final_ped_error = 0
    avg_bic_error = 0
    final_bic_error = 0
    avg_car_error = 0
    final_car_error = 0

    for batch in range(dataloader.num_batches):
        start = time.time()

        # Get the next batch
        x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

        # Construct ST graph
        stgraph.readGraph(x)

        nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

        # Convert to cuda variables
        nodes = Variable(torch.from_numpy(nodes).float(), volatile=True)
        edges = Variable(torch.from_numpy(edges).float(), volatile=True)
        if saved_args.use_cuda:
            nodes = nodes.cuda()
            edges = edges.cuda()

        # Separate out the observed part of the trajectory
        obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = (
            nodes[:sample_args.obs_length],
            edges[:sample_args.obs_length],
            nodesPresent[:sample_args.obs_length],
            edgesPresent[:sample_args.obs_length],
        )

        # Sample function
        ret_nodes, ret_attn = sample(
            obs_nodes,
            obs_edges,
            obs_nodesPresent,
            obs_edgesPresent,
            sample_args,
            net,
            nodes,
            edges,
            nodesPresent,
        )

        # Compute mean and final displacement errors separately per class
        # (the aggregate get_mean_error/get_final_error calls used in
        # Example #1 are superseded by the per-class variants below)
        avg_ped_error_delta, avg_bic_error_delta, avg_car_error_delta = get_mean_error_separately(
            ret_nodes[sample_args.obs_length:].data,
            nodes[sample_args.obs_length:].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length:],
            saved_args.use_cuda,
        )
        avg_ped_error += avg_ped_error_delta
        avg_bic_error += avg_bic_error_delta
        avg_car_error += avg_car_error_delta

        final_ped_error_delta, final_bic_error_delta, final_car_error_delta = get_final_error_separately(
            ret_nodes[sample_args.obs_length:].data,
            nodes[sample_args.obs_length:].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length:],
        )
        final_ped_error += final_ped_error_delta
        final_bic_error += final_bic_error_delta
        final_car_error += final_car_error_delta

        end = time.time()

        print(
            "Processed trajectory number : ",
            batch,
            "out of",
            dataloader.num_batches,
            "trajectories in time",
            end - start,
        )
        if saved_args.use_cuda:
            results.append((
                nodes.data.cpu().numpy(),
                ret_nodes.data.cpu().numpy(),
                nodesPresent,
                sample_args.obs_length,
                ret_attn,
                frameIDs,
            ))
        else:
            results.append((
                nodes.data.numpy(),
                ret_nodes.data.numpy(),
                nodesPresent,
                sample_args.obs_length,
                ret_attn,
                frameIDs,
            ))

        # Reset the ST graph
        stgraph.reset()

    # print("Total mean error of the model is ", total_error / dataloader.num_batches)
    # print(
    #    "Total final error of the model is ", final_error / dataloader.num_batches
    # )  # num_batches = 10
    # Note: the location labels in the two format strings below do not match
    # the per-class (ped/bic/car) variables being printed; they appear to be
    # leftovers from an earlier per-scene breakdown.
    print(
        "AVG disp error:     singapore-onenorth: {}       boston-seaport: {}        singapore-queenstown:{}"
        .format(
            avg_ped_error / dataloader.num_batches,
            avg_bic_error / dataloader.num_batches,
            avg_car_error / dataloader.num_batches,
        ))

    print("total average error:    {}".format(
        (avg_ped_error + avg_bic_error + avg_car_error) /
        dataloader.num_batches / 3))

    print(
        "Final disp error:   singapore-onenorth: {}       boston-seaport: {}        singapore-queenstown:{}"
        .format(
            final_ped_error / dataloader.num_batches,
            final_bic_error / dataloader.num_batches,
            final_car_error / dataloader.num_batches,
        ))
    print("total final error:    {}".format(
        (final_ped_error + final_bic_error + final_car_error) /
        dataloader.num_batches / 3))

    print("Saving results")
    with open(os.path.join(save_directory, "results.pkl"), "wb") as f:
        pickle.dump(results, f)
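
This example relies on get_mean_error_separately and get_final_error_separately, which are not shown. As a point of reference, here is a minimal, hypothetical sketch of the average-displacement-error computation such metrics build on; the tensor layout and the per-timestep nodes_present lists are assumptions inferred from how the arguments are sliced above.

import torch

def average_displacement_error(pred_nodes, true_nodes, nodes_present):
    """ADE over the prediction horizon; a sketch, not the original metric.

    pred_nodes, true_nodes: (pred_length, num_nodes, 2) position tensors.
    nodes_present: per-timestep lists of node indices present in the frame.
    """
    error, count = 0.0, 0
    for t in range(pred_nodes.size(0)):
        for node in nodes_present[t]:
            # Euclidean distance between predicted and true position at step t
            error += torch.norm(pred_nodes[t, node] - true_nodes[t, node], p=2).item()
            count += 1
    return error / count if count else 0.0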
Example #3
def train(args):
    datasets = list(range(5))
    # Remove the leave out dataset from the datasets
    datasets.remove(args.leaveDataset)
    # datasets = [0]
    # args.leaveDataset = 0

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            datasets,
                            forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = 'log_w_goal/'
    log_directory += str(args.leaveDataset) + '/'
    log_directory += 'log_attention'

    # Logging file
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    save_directory = 'save_w_goal/'
    save_directory += str(args.leaveDataset) + '/'
    save_directory += 'save_attention'

    # Open the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_' + str(x) + '.tar')

    # Initialize net
    net = SRNN(args)
    net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)

    learning_rate = args.learning_rate
    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.
                  format(epoch * dataloader.num_batches + batch,
                         args.num_epochs * dataloader.num_batches, epoch,
                         loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)

                loss_batch += loss.data[0]

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')

        # Save the model after each epoch
        print('Saving model')
        torch.save(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
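
The checkpoint dictionary saved after every epoch above carries the keys 'epoch', 'state_dict', and 'optimizer_state_dict', which is enough to resume training. A minimal resume helper under those assumptions (this helper is a sketch, not part of the original code):

import os
import torch

def resume_from_checkpoint(net, optimizer, save_directory, epoch):
    """Load the per-epoch checkpoint written by the training loop above, if present."""
    path = os.path.join(save_directory, 'srnn_model_' + str(epoch) + '.tar')
    if not os.path.isfile(path):
        return 0  # nothing to resume from; start at epoch 0
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1  # first epoch still to be trained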
Example #4
def train(args):
    datasets = list(range(5))
    # Remove the leave out dataset from the datasets
    datasets.remove(args.leaveDataset)

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            datasets,
                            forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(args.batch_size, args.seq_length + 1)

    # Log directory
    log_directory = 'log/'
    log_directory += str(args.leaveDataset) + '/'

    # Logging files
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    save_directory = 'save/'
    save_directory += str(args.leaveDataset) + '/'

    # Dump the arguments into the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory,
                            'social_lstm_model_' + str(x) + '.tar')

    # Initialize net
    net = SocialLSTM(args)
    net.cuda()

    optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate)
    learning_rate = args.learning_rate

    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, d = dataloader.next_batch()

            # Construct the stgraph
            stgraph.readGraph(x)

            loss_batch = 0

            # For each sequence
            for sequence in range(dataloader.batch_size):
                # Get the data corresponding to the current sequence
                x_seq, d_seq = x[sequence], d[sequence]

                # Dataset dimensions
                if d_seq == 0 and datasets[0] == 0:
                    dataset_data = [640, 480]
                else:
                    dataset_data = [720, 576]

                # Compute grid masks
                grid_seq = getSequenceGridMask(x_seq, dataset_data,
                                               args.neighborhood_size,
                                               args.grid_size)
                obst_seq = get_seq_mask(x_seq, d_seq, dataset_data,
                                        args.neighborhood_size, args.grid_size)

                # Get the node features and nodes present from stgraph
                nodes, _, nodesPresent, _ = stgraph.getSequence(sequence)

                # Construct variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                # nodes = Variable(torch.from_numpy(nodes).float())
                numNodes = nodes.size()[1]
                hidden_states = Variable(torch.zeros(numNodes,
                                                     args.rnn_size)).cuda()
                cell_states = Variable(torch.zeros(numNodes,
                                                   args.rnn_size)).cuda()
                # hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                # cell_states = Variable(torch.zeros(numNodes, args.rnn_size))

                # Zero out gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _ = net(nodes[:-1], grid_seq[:-1], obst_seq[:-1],
                                    nodesPresent[:-1], hidden_states,
                                    cell_states)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

            # Reset stgraph
            stgraph.reset()
            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.
                  format(epoch * dataloader.num_batches + batch,
                         args.num_epochs * dataloader.num_batches, epoch,
                         loss_batch, end - start))

        loss_epoch /= dataloader.num_batches
        # Log loss values
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Read the st graph from data
            stgraph.readGraph(x)

            # Loss for this batch
            loss_batch = 0

            # For each sequence
            for sequence in range(dataloader.batch_size):
                # Get data corresponding to the current sequence
                x_seq, d_seq = x[sequence], d[sequence]

                # Dataset dimensions
                if d_seq == 0 and datasets[0] == 0:
                    dataset_data = [640, 480]
                else:
                    dataset_data = [720, 576]

                # Compute grid masks
                grid_seq = getSequenceGridMask(x_seq, dataset_data,
                                               args.neighborhood_size,
                                               args.grid_size)
                obst_seq = get_seq_mask(x_seq, d_seq, dataset_data,
                                        args.neighborhood_size, args.grid_size)
                # Get node features and nodes present from stgraph
                nodes, _, nodesPresent, _ = stgraph.getSequence(sequence)

                # Construct variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                # nodes = Variable(torch.from_numpy(nodes).float())
                numNodes = nodes.size()[1]
                # hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                # cell_states = Variable(torch.zeros(numNodes, args.rnn_size))
                hidden_states = Variable(torch.zeros(numNodes,
                                                     args.rnn_size)).cuda()
                cell_states = Variable(torch.zeros(numNodes,
                                                   args.rnn_size)).cuda()

                # Forward prop
                outputs, _, _ = net(nodes[:-1], grid_seq[:-1], obst_seq[:-1],
                                    nodesPresent[:-1], hidden_states,
                                    cell_states)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

            # Reset the stgraph
            stgraph.reset()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch', best_epoch, 'Best validation loss', best_val_loss)
        log_file_curve.write(str(loss_epoch) + '\n')

        # Save the model after each epoch
        print('Saving model')
        torch.save(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

    print('Best epoch', best_epoch, 'Best validation Loss', best_val_loss)
    # Log the best epoch and best validation loss
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
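
Both training loops above minimize Gaussian2DLikelihood over the predicted part of the sequence, but that function is not shown. As a point of reference, here is a minimal sketch of the standard bivariate-Gaussian negative log-likelihood such a loss computes; the 5-parameter output layout (mux, muy, sx, sy, corr) and the exp/tanh activations are assumptions.

import math
import torch

def gaussian_2d_nll(outputs, targets):
    """Bivariate Gaussian NLL; a sketch, not the original Gaussian2DLikelihood.

    outputs: (..., 5) raw network outputs, assumed (mux, muy, sx, sy, corr).
    targets: (..., 2) ground-truth positions.
    """
    mux, muy, sx, sy, corr = outputs.unbind(-1)
    # Constrain scales to be positive and the correlation to (-1, 1)
    sx, sy, corr = torch.exp(sx), torch.exp(sy), torch.tanh(corr)
    dx = (targets[..., 0] - mux) / sx
    dy = (targets[..., 1] - muy) / sy
    one_minus_r2 = 1 - corr ** 2
    z = dx ** 2 + dy ** 2 - 2 * corr * dx * dy
    log_prob = -z / (2 * one_minus_r2) \
               - torch.log(2 * math.pi * sx * sy * torch.sqrt(one_minus_r2))
    return -log_prob.mean()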
Example #5
def main():

    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length',
                        type=int,
                        default=8,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length',
                        type=int,
                        default=12,
                        help='Predicted length of the trajectory')
    # Test dataset
    parser.add_argument('--test_dataset',
                        type=int,
                        default=4,
                        help='Dataset to be tested on')

    parser.add_argument('--sample_dataset',
                        type=int,
                        default=4,
                        help='Dataset to be sampled on')

    # Model to be loaded
    parser.add_argument('--epoch',
                        type=int,
                        default=133,
                        help='Epoch of model to be loaded')
    # [109, 149, 124, 80, 128]  (presumably the chosen epoch per dataset)
    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory
    save_directory = '/home/hesl/PycharmProjects/srnn-pytorch/save/FixedPixel_150epochs_batchsize24_Pruned/'
    save_directory += str(sample_args.test_dataset) + '/'
    save_directory += 'save_attention'

    ouput_directory = '/home/hesl/PycharmProjects/srnn-pytorch/save/'
    # ouput_directory += str(sample_args.test_dataset) + '/'
    ouput_directory = save_directory  # results are written next to the checkpoints

    # Define the path for the config file for saved args
    with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SRNN(saved_args, True)
    net.cuda()
    #net.forward()

    checkpoint_path = os.path.join(
        save_directory, 'srnn_model_' + str(sample_args.epoch) + '.tar')

    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at {}'.format(model_epoch))

    # homography
    H = np.loadtxt(H_path[sample_args.sample_dataset])

    # Dataset to get data from (sample_dataset overrides the test_dataset choice)
    dataset = [sample_args.sample_dataset]

    dataloader = DataLoader(1,
                            sample_args.pred_length + sample_args.obs_length,
                            dataset,
                            True,
                            infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    NumberofSampling = 10

    for i in range(NumberofSampling):

        results = []

        # Variable to maintain total error
        total_error = 0
        final_error = 0

        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get the next batch
            x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

            # Construct ST graph
            stgraph.readGraph(x)

            nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

            # Convert to cuda variables
            nodes = Variable(torch.from_numpy(nodes).float(),
                             volatile=True).cuda()
            edges = Variable(torch.from_numpy(edges).float(),
                             volatile=True).cuda()

            # Separate out the observed part of the trajectory
            obs_nodes = nodes[:sample_args.obs_length]
            obs_edges = edges[:sample_args.obs_length]
            obs_nodesPresent = nodesPresent[:sample_args.obs_length]
            obs_edgesPresent = edgesPresent[:sample_args.obs_length]

            # Sample function
            ret_nodes, ret_attn = sample(obs_nodes, obs_edges,
                                         obs_nodesPresent, obs_edgesPresent,
                                         sample_args, net, nodes, edges,
                                         nodesPresent)

            # Compute mean and final displacement error
            total_error += get_mean_error(
                ret_nodes[sample_args.obs_length:].data,
                nodes[sample_args.obs_length:].data,
                nodesPresent[sample_args.obs_length - 1],
                nodesPresent[sample_args.obs_length:], H,
                sample_args.sample_dataset)
            final_error += get_final_error(
                ret_nodes[sample_args.obs_length:].data,
                nodes[sample_args.obs_length:].data,
                nodesPresent[sample_args.obs_length - 1],
                nodesPresent[sample_args.obs_length:], H,
                sample_args.sample_dataset)

            end = time.time()

            print('Processed trajectory number : ', batch, 'out of',
                  dataloader.num_batches, 'trajectories in time', end - start)

            # Store results
            results.append(
                (nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(),
                 nodesPresent, sample_args.obs_length, ret_attn, frameIDs,
                 total_error / dataloader.num_batches,
                 final_error / dataloader.num_batches))

            # Reset the ST graph
            stgraph.reset()

        print('Total mean error of the model is ',
              total_error / dataloader.num_batches)
        print('Total final error of the model is ',
              final_error / dataloader.num_batches)

        current_mean_error = total_error / dataloader.num_batches
        current_final_error = final_error / dataloader.num_batches
        if i == 0:
            min_current_mean_error = current_mean_error
            min_current_final_error = current_final_error
            min_index = i
            print('Saving initial results on {}'.format(i))
            with open(os.path.join(ouput_directory, 'results.pkl'), 'wb') as f:
                pickle.dump(results, f)
        else:
            if current_mean_error < min_current_mean_error:
                min_current_mean_error = current_mean_error
                min_current_final_error = current_final_error
                min_index = i
                print('Found smaller error on sampling {}, saving results'.format(i))
                print('current_mean_error: {} and current_final_error: {}'.format(
                    current_mean_error, current_final_error))
                with open(os.path.join(ouput_directory, 'results.pkl'),
                          'wb') as f:
                    pickle.dump(results, f)

    print(
        'Minimum total mean error is {} and minimum final error is {} (on sampling run {})'
        .format(min_current_mean_error, min_current_final_error, min_index))
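
Stripped of the data plumbing, the keep-the-best-of-N pattern used above looks like this; evaluate_once is a hypothetical stand-in for one full sampling pass over the test batches.

import random

def evaluate_once(run_index):
    """Hypothetical stand-in: one stochastic evaluation pass over the test set."""
    results = {'run': run_index}
    return random.random(), results  # (mean displacement error, stored results)

best_mean_error, best_results = None, None
for i in range(10):  # NumberofSampling in the example above
    mean_error, results = evaluate_once(i)
    if best_mean_error is None or mean_error < best_mean_error:
        best_mean_error, best_results = mean_error, results
        print('Found smaller error on run {}: {:.3f}'.format(i, mean_error))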
Example #6
def main():

    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length', type=int, default=8,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length', type=int, default=12,
                        help='Predicted length of the trajectory')
    # Test dataset
    parser.add_argument('--test_dataset', type=int, default=3,
                        help='Dataset to be tested on')

    # Model to be loaded
    parser.add_argument('--epoch', type=int, default=107,
                        help='Epoch of model to be loaded')

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory. Note: the second assignment hardcodes dataset 1 and
    # overrides the per-dataset path built on the first line.
    save_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/FixedPixel_Normalized_150epoch/' + str(sample_args.test_dataset) + '/'
    save_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/FixedPixel_Normalized_150epoch/1/'
    ouput_directory = '/home/hesl/PycharmProjects/social-lstm-pytorch/save/'


    # Define the path for the config file for saved args
    with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SocialLSTM(saved_args, True)
    net.cuda()

    # Get the checkpoint path
    checkpoint_path = os.path.join(save_directory, 'social_lstm_model_'+str(sample_args.epoch)+'.tar')
    # checkpoint_path = os.path.join(save_directory, 'srnn_model.tar')
    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at epoch', model_epoch)

    # Homography matrix for pixel-to-world conversion
    H = np.loadtxt(H_path[sample_args.test_dataset])

    # Test dataset
    dataset = [sample_args.test_dataset]

    # Create the DataLoader object
    dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length, dataset, True, infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    results = []

    # Variable to maintain total error
    total_error = 0
    final_error = 0

    # For each batch
    for batch in range(dataloader.num_batches):
        start = time.time()

        # Get data
        x, _, d = dataloader.next_batch(randomUpdate=False)

        # Get the sequence
        x_seq, d_seq = x[0], d[0]

        # Dimensions of the dataset. Note: the hardcoded [1224, 370] below
        # overrides the ETH/UCY defaults chosen by this if/else.
        if d_seq == 0 and dataset[0] == 0:
            dimensions = [640, 480]
        else:
            dimensions = [720, 576]

        dimensions = [1224, 370]

        # Get the grid masks for the sequence
        grid_seq = getSequenceGridMask(x_seq, dimensions, saved_args.neighborhood_size, saved_args.grid_size)

        # Construct ST graph
        stgraph.readGraph(x)

        # Get nodes and nodesPresent
        nodes, _, nodesPresent, _ = stgraph.getSequence(0)
        nodes = Variable(torch.from_numpy(nodes).float(), volatile=True).cuda()

        # Extract the observed part of the trajectories
        obs_nodes, obs_nodesPresent, obs_grid = (nodes[:sample_args.obs_length],
                                                 nodesPresent[:sample_args.obs_length],
                                                 grid_seq[:sample_args.obs_length])

        # The sample function
        ret_nodes = sample(obs_nodes, obs_nodesPresent, obs_grid, sample_args, net, nodes, nodesPresent, grid_seq, saved_args, dimensions)
        # print(nodes[sample_args.obs_length:].data)
        # Record the mean and final displacement error
        total_error += get_mean_error(ret_nodes[sample_args.obs_length:].data,
                                      nodes[sample_args.obs_length:].data,
                                      nodesPresent[sample_args.obs_length - 1],
                                      nodesPresent[sample_args.obs_length:],
                                      H, sample_args.test_dataset)
        final_error += get_final_error(ret_nodes[sample_args.obs_length:].data,
                                       nodes[sample_args.obs_length:].data,
                                       nodesPresent[sample_args.obs_length - 1],
                                       nodesPresent[sample_args.obs_length:],
                                       H, sample_args.test_dataset)

        end = time.time()

        print('Processed trajectory number : ', batch, 'out of', dataloader.num_batches, 'trajectories in time', end - start)

        results.append((nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(), nodesPresent, sample_args.obs_length))

        # Reset the ST graph
        stgraph.reset()

    print('Total mean error of the model is ', total_error / dataloader.num_batches)
    print('Total final error of the model is ', final_error / dataloader.num_batches)

    print('Saving results')
    with open(os.path.join(ouput_directory, 'results.pkl'), 'wb') as f:
        pickle.dump(results, f)
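
This example passes a homography H into the error metrics, presumably to map pixel coordinates into world coordinates before measuring displacement. A minimal sketch of that conversion (the direction of the mapping is an assumption; the ETH/UCY H.txt files are conventionally used this way):

import numpy as np

def pixels_to_world(points, H):
    """Apply a 3x3 homography to an (N, 2) array of pixel coordinates."""
    homogeneous = np.hstack([points, np.ones((points.shape[0], 1))])
    mapped = homogeneous @ H.T
    return mapped[:, :2] / mapped[:, 2:3]  # divide out the homogeneous scale

# Example: two pixel positions through an identity homography (stand-in for H.txt)
print(pixels_to_world(np.array([[320.0, 240.0], [100.0, 50.0]]), np.eye(3)))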
Example #7
def train(args):
    ## 19 Feb: moved training files for datasets 3 and 4 to fine_obstacle under copy_2

    # os.chdir('/home/serene/PycharmProjects/srnn-pytorch-master/')
    # context_factor is an experiment for including potential destinations of pedestrians in the graph.
    # did not yield good improvement
    os.chdir('/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master')
    # os.chdir('/home/siri0005/srnn-pytorch-master')
    # base_dir = '../fine_obstacle/prelu/p_02/'

    base_dir = '../fine_obstacle/prelu/p_02/'
    # base_dir = '../MultiNodeAttn_HH/'
    # os.chdir('/home/serene/Documents/KITTIData/GT')

    datasets = list(range(5))
    # Remove the leave out dataset from the datasets
    datasets.remove(args.leaveDataset)
    # datasets = [0]
    # args.leaveDataset = 0

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, datasets, forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    # log_directory = './log/world_data/normalized_01/'
    log_directory = base_dir+'log/'
    log_directory += str(args.leaveDataset) + '/'
    log_directory += 'log_attention/'

    # Logging file
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    # save_directory = './save/world_data/normalized_01/'
    save_directory = base_dir+'save/'
    save_directory += str(args.leaveDataset)+'/'
    save_directory += 'save_attention/'

    # Log the PReLU parameter (net.alpha, written once per backward pass)
    param_log_dir = save_directory + 'param_log.txt'
    param_log = open(param_log_dir, 'w')

    # Open the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_'+str(x)+'.tar')

    # Initialize net
    net = SRNN(args)
    # srnn_model = SRNN(args)

    # net = torch.nn.DataParallel(srnn_model)

    # CUDA_VISIBLE_DEVICES = 1
    net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)

    # learning_rate = args.learning_rate
    print('Training begin')
    best_val_loss = 100
    best_epoch = 0
    # start_epoch = 0
    # (Commented-out resume logic: load a saved checkpoint, restore it via
    #  net.load_state_dict / optimizer.load_state_dict, and continue the
    #  epoch range from ckp['epoch'].)
    rang = range(args.num_epochs)

    # Training
    for epoch in rang:
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0
        # if args.leaveDataset == 2 and epoch == 0:
        #     dataloader.num_batches = dataloader.num_batches + start_epoch
        #     epoch += start_epoch

        # For each batch

        # Flag intended to control carrying previous hidden states over to the
        # current ones (the statefulness logic itself is commented out below).
        stateful = True

        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)  # shuffled input, stateless LSTM

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]], d, args.distance_thresh)

                nodes, edges, nodesPresent, edgesPresent, obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

                ## Modification: reset hidden and cell states only once per batch,
                ## keeping states updated across sequences (31 Jan).

                # Define hidden states
                numNodes = nodes.size()[1]
                numObsNodes = obsNodes.size()[1]

                # if not stateful:
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                ## Update (25 Jan): initializing the hidden-state transition with
                ## negative ones did not lead to any new learning results; the
                ## network converged quickly and the loss did not drop below -6.
                hidden_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                hidden_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                cell_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                cell_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, h_node_rnn, h_edge_rnn, cell_node_rnn, cell_edge_rnn, \
                    o_h_node_rnn, o_h_edge_rnn, o_cell_node_rnn, o_cell_edge_rnn, _ = net(
                        nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs,
                        obsNodes[:args.seq_length], obsEdges[:args.seq_length],
                        obsNodesPresent[:-1], obsEdgesPresent[:-1],
                        hidden_states_obs_node_RNNs, hidden_states_obs_edge_RNNs,
                        cell_states_obs_node_RNNs, cell_states_obs_edge_RNNs)

                # (Commented-out statefulness experiment: carry hidden/cell
                # states across sequences, padding them when new pedestrians
                # enter the scene.)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]
                # Compute gradients
                loss.backward(retain_graph=True)  # retain_variables was renamed retain_graph
                param_log.write(str(net.alpha.data[0]) + '\n')

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print(
                '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(epoch * dataloader.num_batches + batch,
                                                                                    args.num_epochs * dataloader.num_batches,
                                                                                    epoch,
                                                                                    loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch)+','+str(loss_epoch)+',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0
        # dataloader.valid_num_batches = dataloader.valid_num_batches + start_epoch
        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)  # stateless LSTM, no shuffling

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]], d, args.distance_thresh)

                nodes, edges, nodesPresent, edgesPresent, obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                numObsNodes = obsNodes.size()[1]
                hidden_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                hidden_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                cell_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                cell_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()
                outputs, h_node_rnn, h_edge_rnn, cell_node_rnn, cell_edge_rnn, \
                    o_h_node_rnn, o_h_edge_rnn, o_cell_node_rnn, o_cell_edge_rnn, _ = net(
                        nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs,
                        obsNodes[:args.seq_length], obsEdges[:args.seq_length],
                        obsNodesPresent[:-1], obsEdgesPresent[:-1],
                        hidden_states_obs_node_RNNs, hidden_states_obs_edge_RNNs,
                        cell_states_obs_node_RNNs, cell_states_obs_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)

                loss_batch += loss.data[0]

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')
        # Save the model after each epoch
        print('Saving model')
        torch.save({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path(epoch))

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
    param_log.close()
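The loop above writes one "epoch,train_loss,val_loss" line per epoch to log_curve.txt and the final "best_epoch,best_val_loss" pair to val.txt. A minimal sketch of recovering the best epoch from the curve file, assuming exactly that format (the helper name is hypothetical):

def best_epoch_from_log(log_path):
    # Each line: "epoch,train_loss,val_loss", as written by log_file_curve above
    best_epoch, best_val = None, float('inf')
    with open(log_path) as f:
        for line in f:
            if not line.strip():
                continue
            epoch, _, val_loss = line.strip().split(',')
            if float(val_loss) < best_val:
                best_epoch, best_val = int(epoch), float(val_loss)
    return best_epoch, best_val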
Ejemplo n.º 8
0
def main():
    # os.chdir('/home/serene/Documents/KITTIData/GT/')
    # os.chdir('/home/siri0005/copy_srnn_pytorch/srnn-pytorch-master/')#/srnn-pytorch-master

    os.chdir('/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master')
    H_path = ['./pedestrians/ewap_dataset/seq_eth/H.txt',
              './pedestrians/ewap_dataset/seq_hotel/H.txt',
              './pedestrians/ucy_crowd/data_zara01/H.txt',
              './pedestrians/ucy_crowd/data_zara02/H.txt',
              './pedestrians/ucy_crowd/data_students03/H.txt']

    # ,1,2,3,
    base_dir = '../fine_obstacle/prelu/p_02/'
    for i in [4]:  # ,1,2,3,4
        avg = []
        ade = []
        with open(base_dir+'log/{0}/log_attention/val.txt'.format(i)) as val_f:
            if i == 1:
                e = 99
            else:
                best_val = val_f.readline()
                [e, val] = best_val.split(',')

        parser = argparse.ArgumentParser()
        # Observed length of the trajectory parameter
        parser.add_argument('--obs_length', type=int, default=8,
                            help='Observed length of the trajectory')
        # Predicted length of the trajectory parameter
        parser.add_argument('--pred_length', type=int, default=12,
                            help='Predicted length of the trajectory')
        # Test dataset
        parser.add_argument('--test_dataset', type=int, default=i,
                            help='Dataset to be tested on')

        # Model to be loaded
        parser.add_argument('--epoch', type=int, default=e,
                            help='Epoch of model to be loaded')

        # Parse the parameters
        sample_args = parser.parse_args()

        # Save directory
        save_directory = base_dir+'save/{0}/save_attention'.format(i)
        plot_directory = base_dir + '/selected_plots/'  # 'plot_1/'

        # Define the path for the config file for saved args
        with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
            saved_args = pickle.load(f)

        # Initialize net
        net = SRNN(saved_args, sample_args.test_dataset, True)
        net.cuda()

        ## TODO: visualize trained weights
        # plt.imshow(net.humanNodeRNN.edge_embed.weight)
        # plt.colorbar()
        # plt.show()
        checkpoint_path = os.path.join(save_directory, 'srnn_model_'+str(sample_args.epoch)+'.tar')

        if os.path.isfile(checkpoint_path):
            print('Loading checkpoint')
            checkpoint = torch.load(checkpoint_path)
            # model_iteration = checkpoint['iteration']
            model_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            print('Loaded checkpoint at {}'.format(model_epoch))

        H_mat = np.loadtxt(H_path[i])

        avg = []
        ade = []
        # Dataset to get data from
        dataset = [sample_args.test_dataset]
        sample_ade_error_arr = []
        sample_fde_error_arr = []
        num_nodes = 0
        inner_num_nodes_1 = 0
        inner_num_nodes_2 = 0
        ade_sum = 0
        fde_sum = 0

        dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length, dataset, True, infer=True)

        dataloader.reset_batch_pointer()

        # Construct the ST-graph object
        stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

        results = []

        # Variable to maintain total error
        total_error = 0
        final_error = 0
    
        # TRY this code version
        for batch in range(dataloader.num_batches):
            sample_fde_error = []
            sample_ade_error = []
            running_time_sample = []
            c = 0
            x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

            # Construct ST graph
            stgraph.readGraph(x, ds_ptr=d, threshold=1.0)

            nodes, edges, nodesPresent, edgesPresent, obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()
            nodes = Variable(torch.from_numpy(nodes).float(), volatile=True).cuda()
            edges = Variable(torch.from_numpy(edges).float(), volatile=True).cuda()

            obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
            obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

            # Separate out the observed part of the trajectory
            obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = nodes[:sample_args.obs_length], edges[:sample_args.obs_length], nodesPresent[:sample_args.obs_length], edgesPresent[:sample_args.obs_length]
            # Separate out the observed obstacles in a given sequence
            obsnodes_v, obsEdges_v, obsNodesPresent_v, obsEdgesPresent_v = obsNodes[:sample_args.obs_length], obsEdges[:sample_args.obs_length], obsNodesPresent[:sample_args.obs_length], obsEdgesPresent[:sample_args.obs_length]

            # if c == 0:
            # num_nodes += np.shape(nodes)[1]

            for c in range(10):
                num_nodes += np.shape(nodes)[1]
                start = time.time()
                # Sample function
                ret_nodes, ret_attn = sample(obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent,
                                             obsnodes_v, obsEdges_v, obsNodesPresent_v, obsEdgesPresent_v,
                                             sample_args, net, nodes, edges, nodesPresent)
                end = time.time()
                running_time_sample.append((end-start))
                # print('One-time Sampling took = ', (end - start), ' seconds.')

                # Compute mean and final displacement error
                total_error, num_err_nodes = get_mean_error(ret_nodes[sample_args.obs_length:].data,
                                                            nodes[sample_args.obs_length:].data,
                                                            nodesPresent[sample_args.obs_length - 1],
                                                            nodesPresent[sample_args.obs_length:], H_mat, i)

                # print("ADE errors:", total_error)
                inner_num_nodes_1 += num_err_nodes
                sample_ade_error.append(total_error)

                final_error, _ = get_final_error(ret_nodes[sample_args.obs_length:].data,
                                                 nodes[sample_args.obs_length:].data,
                                                 nodesPresent[sample_args.obs_length - 1],
                                                 nodesPresent[sample_args.obs_length:], H_mat, i)
                

                # print("final errors:", final_error)

                sample_fde_error.append(final_error)
               
                results.append((nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(), nodesPresent, sample_args.obs_length, ret_attn, frameIDs))
       
                stgraph.reset()
            

            sample_ade_error_arr.append(np.sum(sample_ade_error))
            sample_fde_error_arr.append(np.sum(sample_fde_error))

            sample_ade_error = np.sum(sample_ade_error, 0)
        
            if len(sample_ade_error):
                # sample_ade_error /= 10
                sample_ade_error = torch.min(sample_ade_error)
                ade_sum += sample_ade_error
                ade.append(ade_sum) 
      
            # For non-rectangular tensors: reduce each non-empty FDE entry to its sum,
            # dropping the empty ones
            sample_fde_error = [np.sum(e) for e in sample_fde_error if len(e) > 0]

      
            print(sample_fde_error)
            if len(sample_fde_error):
                sample_fde_error = np.min(sample_fde_error)
                fde_sum += sample_fde_error
                avg.append(fde_sum)

            with open(os.path.join(save_directory, 'results.pkl'), 'wb') as f:
                pickle.dump(results, f)

        print('SUMMARY **************************//')

        # running_time_sample is re-created each batch, so this reports the last batch's timings
        print('One-time sampling took =', np.average(running_time_sample), 'seconds on average.')
        print(np.sum(ade), '   ', np.sum(avg))
        print('average ADE', np.sum(ade) / (sample_args.pred_length * num_nodes))
        print('average FDE', np.sum(avg) / (num_nodes * 10))
       
        with open(os.path.join(save_directory, 'sampling_results.txt'), 'wb') as o:
            np.savetxt(o, (np.sum(ade) / (sample_args.pred_length * num_nodes),
                           np.sum(avg) / inner_num_nodes_1))
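Note that this tester uses a best-of-N protocol: each test batch is sampled ten times (the "for c in range(10)" loop) and only the minimum ADE/FDE over those samples enters the running sums. A toy, self-contained sketch of that reduction (values are illustrative, with three samples per sequence instead of ten):

import numpy as np

# ADE of 3 samples for each of 3 test sequences (toy values)
ade_samples = [[0.62, 0.55, 0.71], [0.40, 0.48, 0.39], [0.90, 0.81, 0.85]]
best_of_n_ade = sum(min(s) for s in ade_samples) / len(ade_samples)
print('best-of-N ADE:', best_of_n_ade)  # averages the per-sequence minima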
Ejemplo n.º 9
0
def train(args):
    print("INPUT SEQUENCE LENGTH: {}".format(args.seq_length))
    print("OUTPUT SEQUENCE LENGTH: {}".format(args.pred_length))
    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            forcePreProcess=False)
    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = "../log/"

    # Logging file
    log_file_curve = open(os.path.join(log_directory, "log_curve.txt"), "w")
    log_file = open(os.path.join(log_directory, "val.txt"), "w")

    # Save directory
    save_directory = "../../save_weight/"

    # Open the configuration file and save the current argument settings
    with open(os.path.join(save_directory, "config.pkl"), "wb") as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, "4Dgraph.S-{}.P-{}.srnn_model_"\
                            .format(args.seq_length, args.pred_length)\
                            + str(x) + ".tar")

    # Initialize net
    net = SRNN(args)
    if args.use_cuda:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-5)
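    # weight_decay=1e-5 adds L2 regularization through Adam's update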

    # learning_rate = args.learning_rate
    logging.info("Training begin")
    best_val_loss = 100
    best_epoch = 0
    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)  # Initialization
        loss_epoch = 0

        # For each batch
        # dataloader.num_batches = 10; one epoch has 10 batches
        for batch in range(dataloader.num_batches):
            start = time.time()
            # Get batch data, mini-batch
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):  # process each sequence in the mini-batch
                # Construct the graph for the current sequence in {nodes, edges}
                stgraph.readGraph([x[sequence]])
                # Graph information for the current sequence
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                ##### Convert to cuda variables #####
                nodes = Variable(torch.from_numpy(nodes).float())
                # nodes[0] holds the coordinates of every person (object) present in frame 0
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()

                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(  # NOTE: 0 for pedestrians, 1 for bicycles, 2 for vehicles
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                # Zero out the gradients // Initialization Step
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()
                # embed()
                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size  ##### NOTE: Expected Loss; E[L]
            loss_epoch += loss_batch

            logging.info(
                "{}/{} (epoch {}), train_loss = {:.12f}, time/batch = {:.12f}".
                format(
                    epoch * dataloader.num_batches + batch,
                    args.num_epochs * dataloader.num_batches,
                    epoch,
                    loss_batch,
                    end - start,
                ))
        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch) + "," + str(loss_epoch) + ",")

        #####################     Validation Part     #####################
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data

            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)

                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size  #### NOTE: Expected Loss; E[L]
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        logging.info("(epoch {}), valid_loss = {:.3f}".format(
            epoch, loss_epoch))
        logging.info("Best epoch {}, Best validation loss {}".format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + "\n")

        # Save the model after each epoch
        logging.info("Saving model")
        torch.save(
            {
                "epoch": epoch,
                "state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path(epoch),
        )

    # Record the best epoch and best validation loss overall
    logging.info("Best epoch {}, Best validation loss {}".format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + "," + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
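Gaussian2DLikelihood is called throughout these examples but not defined in this listing. In this model family the network's five output channels are usually read as (mu_x, mu_y, sigma_x, sigma_y, rho) of a bivariate Gaussian over the next position, and the loss is its negative log-likelihood. A minimal sketch under that assumption (the exp/tanh squashing of sigma and rho is the conventional choice, not taken from this code):

import math
import torch

def gaussian_2d_nll(outputs, targets):
    # outputs: (..., 5) raw channels; targets: (..., 2) observed positions
    mux, muy, sx, sy, corr = outputs.unbind(-1)
    sx, sy = torch.exp(sx), torch.exp(sy)   # standard deviations > 0
    corr = torch.tanh(corr)                 # correlation in (-1, 1)
    nx = (targets[..., 0] - mux) / sx
    ny = (targets[..., 1] - muy) / sy
    z = nx ** 2 + ny ** 2 - 2 * corr * nx * ny
    one_minus_r2 = 1 - corr ** 2
    log_pdf = -z / (2 * one_minus_r2) \
              - torch.log(2 * math.pi * sx * sy * torch.sqrt(one_minus_r2))
    return -log_pdf.mean()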
Ejemplo n.º 10
0
def test(args):
    data = np.load('./data_load2.npz')
    keys = list(data.keys())
    keys_train = keys[0:-3]
    keys_eval = keys[-2:]
    dataset_eval = TrajectoryDataset(data,
                                     args.seq_length - args.pred_length + 1,
                                     args.pred_length, keys_eval)
    dataloader = NewDataLoader(dataset_eval, batch_size=args.batch_size)

    stgraph = ST_GRAPH(1, args.seq_length + 1)

    net = torch.load(args.model, map_location=torch.device('cpu'))
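    # NOTE: torch.load unpickles the whole module object here, so the model's
    # class definition must be importable under the same path when loading.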
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    dataloader.reset_batch_pointer()

    # For each batch
    for batch in range(dataloader.num_batches):
        if batch >= 10:
            break
        start = time.time()

        # Get batch data
        # x, _, _, d = dataloader.next_batch(randomUpdate=True)
        x = dataloader.next_batch()

        # Loss for this batch
        # loss_batch = 0
        batch_loss = Variable(torch.zeros(1))
        # batch_loss = batch_loss.cuda()

        # For each sequence in the batch
        for sequence in range(dataloader.batch_size):
            # Construct the graph for the current sequence
            stgraph.readGraph([x[sequence]])

            nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

            # Convert to cuda variables
            # nodes = Variable(torch.from_numpy(nodes).float()).cuda()
            # edges = Variable(torch.from_numpy(edges).float()).cuda()
            nodes = Variable(torch.from_numpy(nodes).float())
            edges = Variable(torch.from_numpy(edges).float())

            # Define hidden states
            numNodes = nodes.size()[1]
            # hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
            # hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
            hidden_states_node_RNNs = Variable(
                torch.zeros(numNodes, args.human_node_rnn_size))
            hidden_states_edge_RNNs = Variable(
                torch.zeros(numNodes * numNodes,
                            args.human_human_edge_rnn_size))

            # cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
            # cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
            cell_states_node_RNNs = Variable(
                torch.zeros(numNodes, args.human_node_rnn_size))
            cell_states_edge_RNNs = Variable(
                torch.zeros(numNodes * numNodes,
                            args.human_human_edge_rnn_size))

            # Zero out the gradients
            net.zero_grad()

            # Forward prop
            outputs, _, _, _, _, _ = net(
                nodes[:args.seq_length], edges[:args.seq_length],
                nodesPresent[:-1], edgesPresent[:-1], hidden_states_node_RNNs,
                hidden_states_edge_RNNs, cell_states_node_RNNs,
                cell_states_edge_RNNs)

            # print(outputs.shape)
            # print(nodes.shape)
            # print(nodesPresent)
            # print('----------------')
            # raise KeyError

            # Compute loss
            # loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
            loss = net.get_square_loss(outputs, nodes[1:])
            batch_loss = batch_loss + loss
            # optimizer.step()  # no backward pass runs in this test loop, so there are no gradients to apply
            # print(loss)
            stgraph.reset()
        end = time.time()
        batch_loss = batch_loss / dataloader.batch_size
        # loss_batch = loss_batch / dataloader.batch_size
        # loss_epoch += loss_batch
        loss_batch = batch_loss.item()

        print('{}/{} , test_loss = {:.3f}, time/batch = {:.3f}'.format(
            batch, dataloader.num_batches, loss_batch, end - start))
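This test() keeps autograd bookkeeping alive and even constructs an optimizer although loss.backward() is never called. A leaner evaluation pass in modern PyTorch wraps the loop in torch.no_grad(), which replaces the volatile=True flag used in the older examples (a sketch with illustrative shapes and names, not this repo's API):

import torch

def evaluate(net, batches, loss_fn):
    # batches: list of (inputs, targets) pairs; loss_fn returns a scalar tensor
    net.eval()
    total = 0.0
    with torch.no_grad():  # no graph is built, so nothing can (or need) be stepped
        for inputs, targets in batches:
            total += loss_fn(net(inputs), targets).item()
    return total / max(len(batches), 1)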
Ejemplo n.º 11
0
def train(args):
    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            forcePreProcess=False)
    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = os.getcwd() + "\\log\\"
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging file
    log_file_curve = open(os.path.join(log_directory, "log_curve.txt"), "w+")
    log_file = open(os.path.join(log_directory, "val.txt"), "w")

    # Save directory
    save_directory = os.getcwd() + "\\save\\"
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Open the configuration file
    with open(os.path.join(save_directory, "config.pkl"), "wb") as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, "srnn_model_" + str(x) + ".tar")

    # Initialize net
    net = SRNN(args)
    net.load_state_dict(torch.load('modelref/srnn_model_271.tar'),
                        strict=False)
    if args.use_cuda:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-5)

    # learning_rate = args.learning_rate
    logging.info("Training begin")
    best_val_loss = 100
    best_epoch = 0

    global plotter
    plotter = VisdomLinePlotter(env_name='main')

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        # dataloader.num_batches = 10. 1 epoch have 10 batches
        for batch in range(dataloader.num_batches):
            start = time.time()
            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )
                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                # nodes[0] holds the coordinates of every person present in frame 0
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()

                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()
                # Forward prop
                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()
                # embed()
                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            logging.info(
                "{}/{} (epoch {}), train_loss = {:.12f}, time/batch = {:.12f}".
                format(
                    epoch * dataloader.num_batches + batch,
                    args.num_epochs * dataloader.num_batches,
                    epoch,
                    loss_batch,
                    end - start,
                ))
        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        plotter.plot('loss', 'train', 'Class Loss', epoch, loss_epoch)
        # Log it
        log_file_curve.write(str(epoch) + "," + str(loss_epoch) + ",")

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data

            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)

                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        plotter.plot('loss', 'val', 'Class Loss', epoch, loss_epoch)

        # Record best epoch and best validation loss
        logging.info("(epoch {}), valid_loss = {:.3f}".format(
            epoch, loss_epoch))
        logging.info("Best epoch {}, Best validation loss {}".format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + "\n")

        # Save the model after each epoch
        logging.info("Saving model")
        torch.save(
            {
                "epoch": epoch,
                "state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path(epoch),
        )

    # Record the best epoch and best validation loss overall
    logging.info("Best epoch {}, Best validation loss {}".format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + "," + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
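VisdomLinePlotter is not defined in this listing. A common minimal helper consistent with the plotter.plot('loss', 'train', 'Class Loss', epoch, loss_epoch) calls above would look like this (an assumption about the helper, built on visdom's real line API):

import numpy as np
from visdom import Visdom

class VisdomLinePlotter:
    """Plots named line series (e.g. train/val loss) to a Visdom server."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, title_name, x, y):
        X, Y = np.array([x]), np.array([y])
        if var_name not in self.plots:
            # First point for this variable: create the window
            self.plots[var_name] = self.viz.line(X=X, Y=Y, env=self.env, opts=dict(
                legend=[split_name], title=title_name,
                xlabel='Epochs', ylabel=var_name))
        else:
            # Later points: append to the existing window
            self.viz.line(X=X, Y=Y, env=self.env, win=self.plots[var_name],
                          name=split_name, update='append')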
Ejemplo n.º 12
0
def main():
    # os.chdir('/home/serene/Documents/KITTIData/GT/')
    # os.chdir('/home/siri0005/srnn-pytorch-master/')#/srnn-pytorch-master

    os.chdir('/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master')
    H_path = ['./pedestrians/ewap_dataset/seq_eth/H.txt',
              './pedestrians/ewap_dataset/seq_hotel/H.txt',
              './pedestrians/ucy_crowd/data_zara01/H.txt',
              './pedestrians/ucy_crowd/data_zara02/H.txt',
              './pedestrians/ucy_crowd/data_students03/H.txt']
    avg = []
    ade = []
    # ,1,2,3,
    # base_dir = '/home/serene/Downloads/srnn-pytorch-master/' #'../MultiNodeAttn_HH/' #'../fine_obstacle/prelu/p_02/'
    base_dir = '../MultiNodeAttn_HH/' #'/home/serene/Downloads/ablation/'
    for i in [1]:
        with open(base_dir+'log/{0}/log_attention/val.txt'.format(i)) as val_f:
            best_val = val_f.readline()
            # e = 45
            [e, val] = best_val.split(',')

        parser = argparse.ArgumentParser()
        # Observed length of the trajectory parameter
        parser.add_argument('--obs_length', type=int, default=8,
                            help='Observed length of the trajectory')
        # Predicted length of the trajectory parameter
        parser.add_argument('--pred_length', type=int, default=12,
                            help='Predicted length of the trajectory')
        # Test dataset
        parser.add_argument('--test_dataset', type=int, default=i,
                            help='Dataset to be tested on')

        # Model to be loaded
        parser.add_argument('--epoch', type=int, default=e,
                            help='Epoch of model to be loaded')

        parser.add_argument('--use_cuda', action="store_true", default=True,
                            help="Use GPU or CPU")

        # Parse the parameters
        sample_args = parser.parse_args()

        # Save directory
        save_directory = base_dir+'save/{0}/save_attention'.format(i)
        plot_directory = base_dir + '/selected_plots/'  # 'plot_1/'
        # save_directory = './srnn-pytorch-master/fine_obstacle/save/{0}/save_attention'.format(i)
        #'/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master/save/pixel_data/100e/'
        #'/home/serene/Documents/InVehicleCamera/save_kitti/'

        # save_directory += str(sample_args.test_dataset)+'/'
        # save_directory += 'save_attention'

        # Define the path for the config file for saved args
        with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
            saved_args = pickle.load(f)

        # Initialize net
        net = SRNN(saved_args, True)
        net.cuda()

        ## TODO: visualize trained weights
        # plt.imshow(net.humanNodeRNN.edge_embed.weight)
        # plt.colorbar()
        # plt.show()
        checkpoint_path = os.path.join(save_directory, 'srnn_model_'+str(sample_args.epoch)+'.tar')

        if os.path.isfile(checkpoint_path):
            print('Loading checkpoint')
            checkpoint = torch.load(checkpoint_path)
            # model_iteration = checkpoint['iteration']
            model_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            print('Loaded checkpoint at {}'.format(model_epoch))

        H_mat = np.loadtxt(H_path[i])

        avg = []
        ade = []
        # Dataset to get data from
        dataset = [sample_args.test_dataset]
        for c in range(30):
            dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length, dataset, True, infer=True)

            dataloader.reset_batch_pointer()

            # Construct the ST-graph object
            stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

            results = []

            # Variable to maintain total error
            total_error = 0
            final_error = 0
            minimum = 1000
            min_final = 1000
            for batch in range(dataloader.num_batches):
                start = time.time()

                # Get the next batch
                x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

                # Construct ST graph
                stgraph.readGraph(x, ds_ptr=d, threshold=1.0)

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()
                #obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float(), volatile=True).cuda()
                edges = Variable(torch.from_numpy(edges).float(), volatile=True).cuda()

                # obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                # obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()
                # NOTE: obs_ : observed
                # Separate out the observed part of the trajectory
                obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = nodes[:sample_args.obs_length], edges[:sample_args.obs_length], nodesPresent[:sample_args.obs_length], edgesPresent[:sample_args.obs_length]

                # Separate out the observed obstacles in a given sequence
                # obsnodes_v, obsEdges_v , obsNodesPresent_v , obsEdgesPresent_v = obsNodes[:sample_args.obs_length], obsEdges[:sample_args.obs_length], obsNodesPresent[:sample_args.obs_length], obsEdgesPresent[:sample_args.obs_length]

                # Sample function
                # (obstacle inputs obsnodes_v, obsEdges_v, obsNodesPresent_v, obsEdgesPresent_v are omitted in this variant)
                ret_nodes, ret_attn = sample(obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent,
                                             sample_args, net, nodes, edges, nodesPresent)

                # Compute mean and final displacement error
                total_error += get_mean_error(ret_nodes[sample_args.obs_length:].data, nodes[sample_args.obs_length:].data,
                                              nodesPresent[sample_args.obs_length - 1],
                                              nodesPresent[sample_args.obs_length:], H_mat, i)

                final_error += get_final_error(ret_nodes[sample_args.obs_length:].data, nodes[sample_args.obs_length:].data,
                                               nodesPresent[sample_args.obs_length - 1],
                                               nodesPresent[sample_args.obs_length:], H_mat, i)

                # if total_error < minimum:
                #     minimum = total_error
                # if final_error < min_final:
                #     min_final = final_error

                end = time.time()

                # Store results
                results.append((nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(), nodesPresent, sample_args.obs_length, ret_attn, frameIDs))
                # zfill = 3

                # for i in range(len(results)):
                #     skip = str(int(results[i][5][0][8])).zfill(zfill)
                #     # img_file = '/home/serene/Documents/video/hotel/frame-{0}.jpg'.format(skip)
                #     # for j in range(20):
                #     #     if i == 40:
                #
                #     img_file = '/home/serene/Documents/copy_srnn_pytorch/data/ucy/zara/zara.png'
                #     name = plot_directory  + 'sequence_zara' + str(i)  # /pedestrian_1
                #     # for k in range(20):
                #     vis.plot_trajectories(results[i][0], results[i][1], results[i][2], results[i][3], name,
                #                       plot_directory, img_file, 1)
                #     if int(skip) >= 999 and zfill < 4:
                #         zfill = zfill + 1
                #     elif int(skip) >= 9999 and zfill < 5:
                #         zfill = zfill + 1

                # Reset the ST graph
                stgraph.reset()

            print('Total mean error of the model is ', total_error / dataloader.num_batches)
            print('Total final error of the model is ', final_error / dataloader.num_batches)
            ade.append(total_error / dataloader.num_batches)
            avg.append(final_error / dataloader.num_batches)
            print('Saving results')
            with open(os.path.join(save_directory, 'results.pkl'), 'wb') as f:
                pickle.dump(results, f)

        print('average FDE', np.average(avg))
        print('average ADE', np.average(ade))

        with open(os.path.join(save_directory, 'sampling_results.txt'), 'wb') as o:
            np.savetxt(o, (ade, avg), fmt='%.03e')
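get_mean_error and get_final_error receive the homography H_mat loaded from each dataset's H.txt; in the EWAP/UCY sets this matrix maps image coordinates to world coordinates so that displacement errors come out in meters. A minimal sketch of applying such a homography to a batch of 2D points (assuming the image-to-world direction; the helper name is hypothetical):

import numpy as np

def to_world(points_xy, H):
    # points_xy: (N, 2) image coordinates; H: (3, 3) homography from H.txt
    homog = np.hstack([points_xy, np.ones((len(points_xy), 1))])  # (N, 3)
    mapped = homog @ H.T
    return mapped[:, :2] / mapped[:, 2:3]  # de-homogenize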
Ejemplo n.º 13
0
def train(args):

    # Construct the DataLoader object
    ## args: (batch_size=50, seq_length=5, datasets=[0, 1, 2, 3, 4, 5, 6], forcePreProcess=False, infer=False)
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, args.train_dataset, forcePreProcess=True)  # seq_length + 1: the target is the input shifted by one step in time

    # Construct the ST-graph object
    ## args: (batch_size=50, seq_length=5)
    stgraph = ST_GRAPH(1, args.seq_length + 1)  # batch size 1: sequences are read into the graph one at a time

    # Log directory
    log_directory = 'log/trainedOn_'+ str(args.train_dataset)
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging file
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory for saving the model
    save_directory = 'save/trainedOn_'+str(args.train_dataset)
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Open the configuration file
    ## store arguments from parser
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file, i.e. model after the particular epoch
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_'+str(x)+'.tar')

    # Initialize net
    net = SRNN(args)
    if args.use_cuda:        
        net = net.cuda()

    # optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)
    optimizer = torch.optim.Adagrad(net.parameters())

    learning_rate = args.learning_rate
    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            ## Format:
            ## x_batch:     input sequence of length self.seq_length
            ## y_batch:     output seq of the same length, shifted by 1 step in time
            ## frame_batch: frame IDs in the batch
            ## d:           current position of dataset pointer (points to the next batch to be loaded)
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()

                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _, _ = net(nodes[:args.seq_length], edges[:args.seq_length], nodesPresent[:-1], edgesPresent[:-1], hidden_states_node_RNNs, hidden_states_edge_RNNs, cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print(
                '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(epoch * dataloader.num_batches + batch,
                                                                                    args.num_epochs * dataloader.num_batches,
                                                                                    epoch,
                                                                                    loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch)+','+str(loss_epoch)+',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                outputs, _, _, _, _, _, _ = net(nodes[:args.seq_length], edges[:args.seq_length], nodesPresent[:-1], edgesPresent[:-1],
                                             hidden_states_node_RNNs, hidden_states_edge_RNNs,
                                             cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)

                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches


        # Save the model (every epoch when args.save_every is set, otherwise on improvement)
        if loss_epoch < best_val_loss or args.save_every:
            print('Saving model')
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

        # save_best_model overwriting the earlier file
        # if loss_epoch < best_val_loss:


        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch
                      
        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch)+'\n')

        

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch)+','+str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
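Each train() above checkpoints a dict with the keys 'epoch', 'state_dict' and 'optimizer_state_dict', so resuming mirrors the save call (a sketch assuming the srnn_model_<epoch>.tar naming used in this example):

import os
import torch

def resume(net, optimizer, save_directory, epoch):
    path = os.path.join(save_directory, 'srnn_model_' + str(epoch) + '.tar')
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1  # next epoch to run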