Example #1
def main():

    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length',
                        type=int,
                        default=8,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length',
                        type=int,
                        default=12,
                        help='Predicted length of the trajectory')

    # Train Dataset
    # Use like:
    # python transpose_inrange.py --train_dataset index_1 index_2 ...
    parser.add_argument('-l',
                        '--train_dataset',
                        nargs='+',
                        type=int,
                        default=[0, 1, 2, 4],
                        help='<Required> training dataset(s) the model is '
                             'trained on: --train_dataset index_1 index_2 ...')

    # Test dataset
    parser.add_argument('--test_dataset',
                        type=int,
                        default=3,
                        help='Dataset to be tested on')

    # Model to be loaded
    parser.add_argument('--epoch',
                        type=int,
                        default=26,
                        help='Epoch of model to be loaded')

    # Use GPU or not
    parser.add_argument('--use_cuda',
                        action="store_true",
                        default=False,
                        help="Use GPU or CPU")

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory
    load_directory = 'save/'
    load_directory += 'trainedOn_' + str(sample_args.train_dataset)

    # Define the path for the config file for saved args
    ## Arguments the parser was given during training (config.pkl is dumped by the training script; cf. the train() example below)
    with open(os.path.join(load_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SRNN(saved_args, True)
    if saved_args.use_cuda:
        net = net.cuda()

    checkpoint_path = os.path.join(
        load_directory, 'srnn_model_' + str(sample_args.epoch) + '.tar')

    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at {}'.format(model_epoch))

    # Dataset to get data from
    dataset = [sample_args.test_dataset]

    dataloader = DataLoader(1,
                            sample_args.pred_length + sample_args.obs_length,
                            dataset,
                            True,
                            infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    results = []

    # Variable to maintain total error
    total_error = 0
    final_error = 0

    for batch in range(dataloader.num_batches):
        start = time.time()

        # Get the next batch
        x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

        # Construct ST graph
        stgraph.readGraph(x)

        nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

        # Convert to cuda variables
        nodes = Variable(torch.from_numpy(nodes).float(), volatile=True)
        edges = Variable(torch.from_numpy(edges).float(), volatile=True)
        if saved_args.use_cuda:
            nodes = nodes.cuda()
            edges = edges.cuda()

        # Separate out the observed part of the trajectory
        obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = (
            nodes[:sample_args.obs_length],
            edges[:sample_args.obs_length],
            nodesPresent[:sample_args.obs_length],
            edgesPresent[:sample_args.obs_length],
        )

        # Sample function
        ret_nodes, ret_attn, ret_new_attn = sample(obs_nodes, obs_edges,
                                                   obs_nodesPresent,
                                                   obs_edgesPresent,
                                                   sample_args, net, nodes,
                                                   edges, nodesPresent)

        # Compute mean and final displacement error
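        # (A sketch of what these helpers are assumed to compute: ADE, the mean
        # L2 distance between predicted and ground-truth positions over all
        # predicted timesteps and present nodes, and FDE, the same distance at
        # the final predicted timestep only.)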
        total_error += get_mean_error(ret_nodes[sample_args.obs_length:].data,
                                      nodes[sample_args.obs_length:].data,
                                      nodesPresent[sample_args.obs_length - 1],
                                      nodesPresent[sample_args.obs_length:],
                                      saved_args.use_cuda)
        final_error += get_final_error(
            ret_nodes[sample_args.obs_length:].data,
            nodes[sample_args.obs_length:].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length:])

        end = time.time()

        print('Processed trajectory number : ', batch, 'out of',
              dataloader.num_batches, 'trajectories in time', end - start)

        # Store results
        if saved_args.use_cuda:
            results.append(
                (nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(),
                 nodesPresent, sample_args.obs_length, ret_attn, ret_new_attn,
                 frameIDs))
        else:
            results.append(
                (nodes.data.numpy(), ret_nodes.data.numpy(), nodesPresent,
                 sample_args.obs_length, ret_attn, ret_new_attn, frameIDs))

        # Reset the ST graph
        stgraph.reset()

    print('Total mean error of the model is ',
          total_error / dataloader.num_batches)
    print('Total final error of the model is ',
          final_error / dataloader.num_batches)

    print('Saving results')
    save_directory = load_directory + '/testedOn_' + str(
        sample_args.test_dataset)
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
    with open(os.path.join(save_directory, 'results.pkl'), 'wb') as f:
        pickle.dump(results, f)
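
These scripts share a common preamble. A minimal import sketch, assuming the module layout of the srnn-pytorch repository (the repo-local module names below are inferred from usage, not confirmed paths):

import argparse
import os
import pickle
import time

import numpy as np
import torch
from torch.autograd import Variable

# Repo-local helpers; module paths are assumptions based on the calls above.
from model import SRNN
from st_graph import ST_GRAPH
from utils import DataLoader
from helper import sample, get_mean_error, get_final_error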
Example #2
def main():
    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument("--obs_length",
                        type=int,
                        default=4,
                        help="Observed length of the trajectory")
    # Predicted length of the trajectory parameter
    parser.add_argument("--pred_length",
                        type=int,
                        default=6,
                        help="Predicted length of the trajectory")
    # Model to be loaded
    parser.add_argument("--epoch",
                        type=int,
                        default=233,
                        help="Epoch of model to be loaded")

    # Use GPU or not
    parser.add_argument("--use_cuda",
                        action="store_true",
                        default=True,
                        help="Use GPU or CPU")

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory
    save_directory = os.getcwd() + "\\save\\"
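    # A portable alternative with the same layout (hedged suggestion):
    # save_directory = os.path.join(os.getcwd(), "save")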
    # Define the path for the config file for saved args
    with open(os.path.join(save_directory, "config.pkl"), "rb") as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SRNN(saved_args, True)
    if saved_args.use_cuda:
        net = net.cuda()

    checkpoint_path = os.path.join(
        save_directory, "srnn_model_" + str(sample_args.epoch) + ".tar")

    if os.path.isfile(checkpoint_path):
        print("Loading checkpoint")
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint["epoch"]
        net.load_state_dict(checkpoint["state_dict"])
        print("Loaded checkpoint at {}".format(model_epoch))

    dataloader = DataLoader(1,
                            sample_args.pred_length + sample_args.obs_length,
                            infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    results = []

    # Variable to maintain total error
    # total_error = 0
    # final_error = 0
    avg_ped_error = 0
    final_ped_error = 0
    avg_bic_error = 0
    final_bic_error = 0
    avg_car_error = 0
    final_car_error = 0

    for batch in range(dataloader.num_batches):
        start = time.time()

        # Get the next batch
        x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

        # Construct ST graph
        stgraph.readGraph(x)

        nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

        # Convert to cuda variables
        nodes = Variable(torch.from_numpy(nodes).float(), volatile=True)
        edges = Variable(torch.from_numpy(edges).float(), volatile=True)
        if saved_args.use_cuda:
            nodes = nodes.cuda()
            edges = edges.cuda()

        # Separate out the observed part of the trajectory
        obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = (
            nodes[:sample_args.obs_length],
            edges[:sample_args.obs_length],
            nodesPresent[:sample_args.obs_length],
            edgesPresent[:sample_args.obs_length],
        )

        # Sample function
        ret_nodes, ret_attn = sample(
            obs_nodes,
            obs_edges,
            obs_nodesPresent,
            obs_edgesPresent,
            sample_args,
            net,
            nodes,
            edges,
            nodesPresent,
        )

        # Compute mean and final displacement error
        """
        total_error += get_mean_error(
            ret_nodes[sample_args.obs_length :].data,
            nodes[sample_args.obs_length :].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length :],
            saved_args.use_cuda,
        )
        final_error += get_final_error(
            ret_nodes[sample_args.obs_length :].data,
            nodes[sample_args.obs_length :].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length :],
        )
        """
        avg_ped_error_delta, avg_bic_error_delta, avg_car_error_delta = get_mean_error_separately(
            ret_nodes[sample_args.obs_length:].data,
            nodes[sample_args.obs_length:].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length:],
            saved_args.use_cuda,
        )
        avg_ped_error += avg_ped_error_delta
        avg_bic_error += avg_bic_error_delta
        avg_car_error += avg_car_error_delta

        final_ped_error_delta, final_bic_error_delta, final_car_error_delta = get_final_error_separately(
            ret_nodes[sample_args.obs_length:].data,
            nodes[sample_args.obs_length:].data,
            nodesPresent[sample_args.obs_length - 1],
            nodesPresent[sample_args.obs_length:],
        )
        final_ped_error += final_ped_error_delta
        final_bic_error += final_bic_error_delta
        final_car_error += final_car_error_delta

        end = time.time()

        print(
            "Processed trajectory number : ",
            batch,
            "out of",
            dataloader.num_batches,
            "trajectories in time",
            end - start,
        )
        if saved_args.use_cuda:
            results.append((
                nodes.data.cpu().numpy(),
                ret_nodes.data.cpu().numpy(),
                nodesPresent,
                sample_args.obs_length,
                ret_attn,
                frameIDs,
            ))
        else:
            results.append((
                nodes.data.numpy(),
                ret_nodes.data.numpy(),
                nodesPresent,
                sample_args.obs_length,
                ret_attn,
                frameIDs,
            ))

        # Reset the ST graph
        stgraph.reset()

    # print("Total mean error of the model is ", total_error / dataloader.num_batches)
    # print(
    #    "Total final error of the model is ", final_error / dataloader.num_batches
    # )  # num_batches = 10
    print(
        "AVG disp error:     singapore-onenorth: {}       boston-seaport: {}        singapore-queenstown:{}"
        .format(
            avg_ped_error / dataloader.num_batches,
            avg_bic_error / dataloader.num_batches,
            avg_car_error / dataloader.num_batches,
        ))

    print("total average error:    {}".format(
        (avg_ped_error + avg_bic_error + avg_car_error) /
        dataloader.num_batches / 3))

    print(
        "Final disp error:   singapore-onenorth: {}       boston-seaport: {}        singapore-queenstown:{}"
        .format(
            final_ped_error / dataloader.num_batches,
            final_bic_error / dataloader.num_batches,
            final_car_error / dataloader.num_batches,
        ))
    print("total final error:    {}".format(
        (final_ped_error + final_bic_error + final_car_error) /
        dataloader.num_batches / 3))

    print("Saving results")
    with open(os.path.join(save_directory, "results.pkl"), "wb") as f:
        pickle.dump(results, f)
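
Each script pickles its per-trajectory tuples to results.pkl. A minimal loading sketch; the tuple layout mirrors the results.append(...) calls in Example #2, and the path is an assumption:

import os
import pickle

with open(os.path.join(os.getcwd(), "save", "results.pkl"), "rb") as f:
    results = pickle.load(f)

# One tuple per evaluated trajectory, matching the append above:
nodes, ret_nodes, nodesPresent, obs_length, ret_attn, frameIDs = results[0]
print(nodes.shape, ret_nodes.shape, obs_length)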
Example #3
def main():

    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length',
                        type=int,
                        default=8,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length',
                        type=int,
                        default=12,
                        help='Predicted length of the trajectory')
    # Test dataset
    parser.add_argument('--test_dataset',
                        type=int,
                        default=4,
                        help='Dataset to be tested on')

    parser.add_argument('--sample_dataset',
                        type=int,
                        default=4,
                        help='Dataset to be sampled on')

    # Model to be loaded
    parser.add_argument('--epoch',
                        type=int,
                        default=133,
                        help='Epoch of model to be loaded')
    #[109,149,124,80,128]
    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory
    save_directory = '/home/hesl/PycharmProjects/srnn-pytorch/save/FixedPixel_150epochs_batchsize24_Pruned/'
    save_directory += str(sample_args.test_dataset) + '/'
    save_directory += 'save_attention'

    output_directory = '/home/hesl/PycharmProjects/srnn-pytorch/save/'
    # output_directory += str(sample_args.test_dataset) + '/'
    output_directory = save_directory  # results are written next to the checkpoints

    # Define the path for the config file for saved args
    with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize net
    net = SRNN(saved_args, True)
    net.cuda()
    #net.forward()

    checkpoint_path = os.path.join(
        save_directory, 'srnn_model_' + str(sample_args.epoch) + '.tar')

    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at {}'.format(model_epoch))

    # homography
    H = np.loadtxt(H_path[sample_args.sample_dataset])

    # Dataset to get data from
    # dataset = [sample_args.test_dataset]  # superseded by the sampling dataset
    dataset = [sample_args.sample_dataset]

    dataloader = DataLoader(1,
                            sample_args.pred_length + sample_args.obs_length,
                            dataset,
                            True,
                            infer=True)

    dataloader.reset_batch_pointer()

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    NumberofSampling = 10
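    # Best-of-N evaluation: the test split is sampled NumberofSampling times,
    # and only the run with the lowest mean error is written out (see the
    # min_current_mean_error bookkeeping below).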

    for i in range(NumberofSampling):

        results = []

        # Variable to maintain total error
        total_error = 0
        final_error = 0

        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get the next batch
            x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

            # Construct ST graph
            stgraph.readGraph(x)

            nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

            # Convert to cuda variables
            nodes = Variable(torch.from_numpy(nodes).float(),
                             volatile=True).cuda()
            edges = Variable(torch.from_numpy(edges).float(),
                             volatile=True).cuda()

            # Separate out the observed part of the trajectory
            obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = (
                nodes[:sample_args.obs_length],
                edges[:sample_args.obs_length],
                nodesPresent[:sample_args.obs_length],
                edgesPresent[:sample_args.obs_length],
            )

            # Sample function
            ret_nodes, ret_attn = sample(obs_nodes, obs_edges,
                                         obs_nodesPresent, obs_edgesPresent,
                                         sample_args, net, nodes, edges,
                                         nodesPresent)

            # Compute mean and final displacement error
            total_error += get_mean_error(
                ret_nodes[sample_args.obs_length:].data,
                nodes[sample_args.obs_length:].data,
                nodesPresent[sample_args.obs_length - 1],
                nodesPresent[sample_args.obs_length:], H,
                sample_args.sample_dataset)
            final_error += get_final_error(
                ret_nodes[sample_args.obs_length:].data,
                nodes[sample_args.obs_length:].data,
                nodesPresent[sample_args.obs_length - 1],
                nodesPresent[sample_args.obs_length:], H,
                sample_args.sample_dataset)

            end = time.time()

            print('Processed trajectory number : ', batch, 'out of',
                  dataloader.num_batches, 'trajectories in time', end - start)

            # Store results
            results.append(
                (nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(),
                 nodesPresent, sample_args.obs_length, ret_attn, frameIDs,
                 total_error / dataloader.num_batches,
                 final_error / dataloader.num_batches))

            # Reset the ST graph
            stgraph.reset()

        print('Total mean error of the model is ',
              total_error / dataloader.num_batches)
        print('Total final error of the model is ',
              final_error / dataloader.num_batches)

        current_mean_error = total_error / dataloader.num_batches
        current_final_error = final_error / dataloader.num_batches
        if i == 0:
            min_current_mean_error = current_mean_error
            min_current_final_error = current_final_error
            min_index = i
            print('Saving initial results on {}'.format(i))
            with open(os.path.join(output_directory, 'results.pkl'), 'wb') as f:
                pickle.dump(results, f)
        else:
            if current_mean_error < min_current_mean_error:
                min_current_mean_error = current_mean_error
                min_current_final_error = current_final_error
                min_index = i
                print('Found Smaller Error on {}, Saving results'.format(i))
                print(
                    'Smaller current_mean_error"{} and current_final_error:{} and '
                    .format(current_mean_error, current_final_error))
                with open(os.path.join(output_directory, 'results.pkl'),
                          'wb') as f:
                    pickle.dump(results, f)

    print(
        'Minimum Total Mean Error is {} and Minimum Final Mean Error is {} on {}th Sampling'
        .format(min_current_mean_error, min_current_final_error, min_index))
Example #4
def main():
    # os.chdir('/home/serene/Documents/KITTIData/GT/')
    # os.chdir('/home/siri0005/copy_srnn_pytorch/srnn-pytorch-master/')#/srnn-pytorch-master

    os.chdir('/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master')
    H_path = ['./pedestrians/ewap_dataset/seq_eth/H.txt',
              './pedestrians/ewap_dataset/seq_hotel/H.txt',
              './pedestrians/ucy_crowd/data_zara01/H.txt',
              './pedestrians/ucy_crowd/data_zara02/H.txt',
              './pedestrians/ucy_crowd/data_students03/H.txt']

    base_dir = '../fine_obstacle/prelu/p_02/'
    for i in [4]:  # extend to [0, 1, 2, 3, 4] to evaluate every dataset
        avg = []
        ade = []
        with open(base_dir+'log/{0}/log_attention/val.txt'.format(i)) as val_f:
            if i == 1:
                e = 99
            else:
                best_val = val_f.readline()
                [e, val] = best_val.split(',')

        parser = argparse.ArgumentParser()
        # Observed length of the trajectory parameter
        parser.add_argument('--obs_length', type=int, default=8,
                            help='Observed length of the trajectory')
        # Predicted length of the trajectory parameter
        parser.add_argument('--pred_length', type=int, default=12,
                            help='Predicted length of the trajectory')
        # Test dataset
        parser.add_argument('--test_dataset', type=int, default=i,
                            help='Dataset to be tested on')

        # Model to be loaded
        parser.add_argument('--epoch', type=int, default=e,
                            help='Epoch of model to be loaded')

        # Parse the parameters
        sample_args = parser.parse_args()

        # Save directory
        save_directory = base_dir+'save/{0}/save_attention'.format(i)
        plot_directory = base_dir + '/selected_plots/'  # 'plot_1/'

        # Define the path for the config file for saved args
        with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
            saved_args = pickle.load(f)

        # Initialize net
        net = SRNN(saved_args, sample_args.test_dataset, True)
        net.cuda()

        ## TODO: visualize trained weights
        # plt.imshow(net.humanNodeRNN.edge_embed.weight)
        # plt.colorbar()
        # plt.show()
        checkpoint_path = os.path.join(save_directory, 'srnn_model_'+str(sample_args.epoch)+'.tar')

        if os.path.isfile(checkpoint_path):
            print('Loading checkpoint')
            checkpoint = torch.load(checkpoint_path)
            # model_iteration = checkpoint['iteration']
            model_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            print('Loaded checkpoint at {}'.format(model_epoch))

        H_mat = np.loadtxt(H_path[i])

        avg = []
        ade = []
        # Dataset to get data from
        dataset = [sample_args.test_dataset]
        sample_ade_error_arr = []
        sample_fde_error_arr = []
        num_nodes = 0
        inner_num_nodes_1 = 0
        inner_num_nodes_2 = 0
        ade_sum = 0
        fde_sum = 0

        dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length, dataset, True, infer=True)

        dataloader.reset_batch_pointer()

        # Construct the ST-graph object
        stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

        results = []

        # Variable to maintain total error
        total_error = 0
        final_error = 0
    
        # TRY this code version
        for batch in range(dataloader.num_batches):
            sample_fde_error = []
            sample_ade_error = []
            running_time_sample = []
            c = 0
            x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

            # Construct ST graph
            stgraph.readGraph(x, ds_ptr=d, threshold=1.0)

            nodes, edges, nodesPresent, edgesPresent, obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()
            nodes = Variable(torch.from_numpy(nodes).float(), volatile=True).cuda()
            edges = Variable(torch.from_numpy(edges).float(), volatile=True).cuda()

            obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
            obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

            # Separate out the observed part of the trajectory
            obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = (
                nodes[:sample_args.obs_length],
                edges[:sample_args.obs_length],
                nodesPresent[:sample_args.obs_length],
                edgesPresent[:sample_args.obs_length],
            )
            # Separate out the observed obstacles in a given sequence
            obsnodes_v, obsEdges_v, obsNodesPresent_v, obsEdgesPresent_v = (
                obsNodes[:sample_args.obs_length],
                obsEdges[:sample_args.obs_length],
                obsNodesPresent[:sample_args.obs_length],
                obsEdgesPresent[:sample_args.obs_length],
            )

            # if c == 0:
            # num_nodes += np.shape(nodes)[1]

            for c in range(10):
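                # Best-of-10 sampling: draw ten stochastic samples for this
                # batch and keep the minimum ADE/FDE across them (aggregated
                # after this loop).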
                num_nodes += np.shape(nodes)[1]
                start = time.time()
                # Sample function
                ret_nodes, ret_attn = sample(obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent, obsnodes_v,
                                             obsEdges_v, obsNodesPresent_v,obsEdgesPresent_v, sample_args, net, nodes, edges, nodesPresent)
                end = time.time()
                running_time_sample.append((end-start))
                # print('One-time Sampling took = ', (end - start), ' seconds.')

                # Compute mean and final displacement error
                total_error, num_present = get_mean_error(
                    ret_nodes[sample_args.obs_length:].data,
                    nodes[sample_args.obs_length:].data,
                    nodesPresent[sample_args.obs_length - 1],
                    nodesPresent[sample_args.obs_length:], H_mat, i)

                # print("ADE errors:", total_error)
                inner_num_nodes_1 += num_present
                sample_ade_error.append(total_error)
                

                final_error, _ = get_final_error(
                    ret_nodes[sample_args.obs_length:].data,
                    nodes[sample_args.obs_length:].data,
                    nodesPresent[sample_args.obs_length - 1],
                    nodesPresent[sample_args.obs_length:], H_mat, i)

                # print("final errors:", final_error)

                sample_fde_error.append(final_error)
               
                results.append((nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(), nodesPresent, sample_args.obs_length, ret_attn, frameIDs))
       
                stgraph.reset()
            

            sample_ade_error_arr.append(np.sum(sample_ade_error))
            sample_fde_error_arr.append(np.sum(sample_fde_error))

            sample_ade_error = np.sum(sample_ade_error, 0)
        
            if len(sample_ade_error):
                # sample_ade_error /= 10
                sample_ade_error = torch.min(sample_ade_error)
                ade_sum += sample_ade_error
                ade.append(ade_sum) 
      
            # For non-rectangular tensors: collapse each per-node FDE tensor to
            # a scalar sum and drop empty entries.
            sample_fde_error = [np.sum(e) for e in sample_fde_error if len(e) > 0]

      
            print(sample_fde_error)
            if len(sample_fde_error):
                sample_fde_error = np.min(sample_fde_error)
                fde_sum += sample_fde_error
                avg.append(fde_sum)

            with open(os.path.join(save_directory, 'results.pkl'), 'wb') as f:
                pickle.dump(results, f)

        print('SUMMARY **************************//')

        print('One-time Sampling took = ', np.average(running_time_sample), ' seconds.')
        print(np.sum(ade), '   ', np.sum(avg))
        print('average ADE', np.sum(ade) / (sample_args.pred_length * num_nodes))
        print('average FDE', np.sum(avg) / (num_nodes * 10))
       
        np.savetxt(os.path.join(save_directory, 'sampling_results.txt'),
                   (np.sum(ade) / (sample_args.pred_length * num_nodes),
                    np.sum(avg) / inner_num_nodes_1))
Example #5
def main():

    parser = argparse.ArgumentParser()
    # Observed length of the trajectory parameter
    parser.add_argument('--obs_length',
                        type=int,
                        default=16,
                        help='Observed length of the trajectory')
    # Predicted length of the trajectory parameter
    parser.add_argument('--pred_length',
                        type=int,
                        default=8,
                        help='Predicted length of the trajectory')
    # Model to be loaded
    parser.add_argument('--epoch',
                        type=int,
                        default=1,
                        help='Epoch of model to be loaded')

    # Parse the parameters
    sample_args = parser.parse_args()

    # Save directory
    save_directory = '../save-beijing-16/'

    # Define the path for the config file for saved args
    with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)
        args = saved_args  # alias used when sizing the hidden states below

    # Initialize net
    net = SRNN(saved_args)
    net.cuda()

    checkpoint_path = os.path.join(
        save_directory, 'srnn_model_' + str(sample_args.epoch) + '.tar')

    if os.path.isfile(checkpoint_path):
        print('Loading checkpoint')
        checkpoint = torch.load(checkpoint_path)
        # model_iteration = checkpoint['iteration']
        model_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        print('Loaded checkpoint at {}'.format(model_epoch))

    # Dataset to get data from
    # dataset = [sample_args.test_dataset]

    # dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length, dataset, True, infer=True)

    # dataloader.reset_batch_pointer()

    filedir = '../../beijing/'

    grp = My_Graph(sample_args.obs_length + sample_args.pred_length,
                   filedir + 'W.pk', filedir + 'traffic_data.csv')

    my_datapt = My_DataPointer(
        grp.getFrameNum(), sample_args.obs_length + sample_args.pred_length, 1)
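    # (My_Graph / My_DataPointer are repo-local helpers, inferred from usage:
    # the graph yields (nodes, edges, nodesPresent, edgesPresent) windows of
    # length obs_length + pred_length, and the pointer iterates test-set start
    # indices via get_test().)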

    # Construct the ST-graph object
    # stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

    atten_res = []
    mape_mean = []
    loss = 0
    # Variable to maintain total error
    total_error = 0
    final_error = 0
    cnt = 0

    # num = my_datapt.test_num();
    num = 1

    fo = open(save_directory + 'output.txt', "w")

    with torch.no_grad():
        for e in range(num):
            for st in my_datapt.get_test():
                start = time.time()

                # Get the next st
                # x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

                # Construct ST graph
                # stgraph.readGraph(x)

                nodes, edges, nodesPresent, edgesPresent = grp.getSequence(st)
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = nodes[:sample_args.obs_length], edges[:sample_args.obs_length], nodesPresent[:sample_args.obs_length], edgesPresent[:sample_args.obs_length]
                # ret_nodes, ret_attn = multiPred(obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent, sample_args, net)

                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                # ret_nodes, _, _, _, _, _ = net(nodes, edges, nodesPresent, edgesPresent,
                #                                  hidden_states_node_RNNs, hidden_states_edge_RNNs,
                #                                  cell_states_node_RNNs, cell_states_edge_RNNs)

                ret_nodes, _ = multiPred(nodes, edges, nodesPresent,
                                         edgesPresent, sample_args, net)

                ret_nodes_numpy = ret_nodes.cpu().numpy()
                ret_nodes_numpy = grp.z_inverse(ret_nodes_numpy)
                nodes_numpy = nodes.cpu().numpy()
                nodes_numpy = grp.z_inverse(nodes_numpy)

                # atten_res.append(ret_attn)
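                # (get_MAPE is assumed to return a per-timestep mean absolute
                # percentage error, mean(|pred - true| / |true|), over the
                # nodes present at the last observed frame.)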
                mape = get_MAPE(ret_nodes_numpy[sample_args.obs_length:],
                                nodes_numpy[sample_args.obs_length:],
                                nodesPresent[sample_args.obs_length - 1])

                end = time.time()

                print(f"    mape:{mape}")
                fo.write(f"    mape:{mape}\n")
                with open("../log-beijing-16/log.py", "w") as f:
                    f.write(f"frame=[];\n")
                    for i in range(ret_nodes_numpy.shape[0]):
                        f.write(f"frame.append([\n")
                        for node in range(ret_nodes_numpy.shape[1]):
                            f.write(
                                f"({ret_nodes_numpy[i,node,0]},{nodes_numpy[i,node,0]}),"
                            )
                        f.write("]);\n")
                if (len(mape_mean) == 0):
                    mape_mean = mape
                else:
                    for i in range(len(mape_mean)):
                        mape_mean[i] += mape[i]

    for i in range(len(mape_mean)):
        mape_mean[i] = mape_mean[i] / my_datapt.test_num()
        print(f"time step{i}:mape={mape_mean[i]}\n")
        fo.write(f"time step{i}:mape={mape_mean[i]}\n")
    with open('../log-beijing-16/attention_out.txt', "w") as f:
        f.write(str(atten_res))
Example #6
def train(args):
    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            forcePreProcess=False)
    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = os.getcwd() + "\\log\\"
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging file
    log_file_curve = open(os.path.join(log_directory, "log_curve.txt"), "w+")
    log_file = open(os.path.join(log_directory, "val.txt"), "w")

    # Save directory
    save_directory = os.getcwd() + "\\save\\"
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Open the configuration file
    with open(os.path.join(save_directory, "config.pkl"), "wb") as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, "srnn_model_" + str(x) + ".tar")

    # Initialize net
    net = SRNN(args)
    net.load_state_dict(torch.load('modelref/srnn_model_271.tar'),
                        strict=False)
    if args.use_cuda:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-5)

    # learning_rate = args.learning_rate
    logging.info("Training begin")
    best_val_loss = 100
    best_epoch = 0

    global plotter
    plotter = VisdomLinePlotter(env_name='main')

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        # dataloader.num_batches = 10; one epoch has 10 batches
        for batch in range(dataloader.num_batches):
            start = time.time()
            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )
                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                # nodes[0] holds the coordinates of every person present in frame 0.
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()

                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()
                # Forward prop
                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
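                # (Gaussian2DLikelihood is assumed to return the negative
                # log-likelihood of the ground-truth positions under the
                # bivariate Gaussian parameters predicted by the network,
                # evaluated over the last pred_length timesteps.)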
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()
                # embed()
                # Compute gradients
                loss.backward()

                # Clip gradients in place to avoid exploding gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            logging.info(
                "{}/{} (epoch {}), train_loss = {:.12f}, time/batch = {:.12f}".
                format(
                    epoch * dataloader.num_batches + batch,
                    args.num_epochs * dataloader.num_batches,
                    epoch,
                    loss_batch,
                    end - start,
                ))
        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        plotter.plot('loss', 'train', 'Class Loss', epoch, loss_epoch)
        # Log it
        log_file_curve.write(str(epoch) + "," + str(loss_epoch) + ",")

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data

            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)

                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        plotter.plot('loss', 'val', 'Class Loss', epoch, loss_epoch)

        # Record best epoch and best validation loss
        logging.info("(epoch {}), valid_loss = {:.3f}".format(
            epoch, loss_epoch))
        logging.info("Best epoch {}, Best validation loss {}".format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + "\n")

        # Save the model after each epoch
        logging.info("Saving model")
        torch.save(
            {
                "epoch": epoch,
                "state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path(epoch),
        )

    # Record the best epoch and best validation loss overall
    logging.info("Best epoch {}, Best validation loss {}".format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + "," + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
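
train() assumes a VisdomLinePlotter helper with a plot(var_name, split_name, title_name, x, y) method. A minimal sketch inferred from those calls (window options and layout are assumptions, not the repository's actual implementation):

import numpy as np
from visdom import Visdom

class VisdomLinePlotter:
    """Plots named line series to a running Visdom server, one window per variable."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, title_name, x, y):
        if var_name not in self.plots:
            # First point for this variable: create a new window.
            self.plots[var_name] = self.viz.line(
                X=np.array([x]), Y=np.array([y]), env=self.env,
                opts=dict(legend=[split_name], title=title_name,
                          xlabel='Epochs', ylabel=var_name))
        else:
            # Subsequent points: append to the existing window.
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env,
                          win=self.plots[var_name], name=split_name,
                          update='append')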
Example #7
def main():
    # os.chdir('/home/serene/Documents/KITTIData/GT/')
    # os.chdir('/home/siri0005/srnn-pytorch-master/')#/srnn-pytorch-master

    os.chdir('/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master')
    H_path = ['./pedestrians/ewap_dataset/seq_eth/H.txt',
              './pedestrians/ewap_dataset/seq_hotel/H.txt',
              './pedestrians/ucy_crowd/data_zara01/H.txt',
              './pedestrians/ucy_crowd/data_zara02/H.txt',
              './pedestrians/ucy_crowd/data_students03/H.txt']
    avg = []
    ade = []
    # base_dir = '/home/serene/Downloads/srnn-pytorch-master/' #'../MultiNodeAttn_HH/' #'../fine_obstacle/prelu/p_02/'
    base_dir = '../MultiNodeAttn_HH/' #'/home/serene/Downloads/ablation/'
    for i in [1]:
        with open(base_dir+'log/{0}/log_attention/val.txt'.format(i)) as val_f:
            best_val = val_f.readline()
            # e = 45
            [e, val] = best_val.split(',')

        parser = argparse.ArgumentParser()
        # Observed length of the trajectory parameter
        parser.add_argument('--obs_length', type=int, default=8,
                            help='Observed length of the trajectory')
        # Predicted length of the trajectory parameter
        parser.add_argument('--pred_length', type=int, default=12,
                            help='Predicted length of the trajectory')
        # Test dataset
        parser.add_argument('--test_dataset', type=int, default=i,
                            help='Dataset to be tested on')

        # Model to be loaded
        parser.add_argument('--epoch', type=int, default=e,
                            help='Epoch of model to be loaded')

        parser.add_argument('--use_cuda', action="store_true", default=True,
                            help="Use GPU or CPU")

        # Parse the parameters
        sample_args = parser.parse_args()

        # Save directory
        save_directory = base_dir+'save/{0}/save_attention'.format(i)
        plot_directory = base_dir +  '/selected_plots/' #'plot_1/'
        # save_directory = './srnn-pytorch-master/fine_obstacle/save/{0}/save_attention'.format(i)
        #'/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master/save/pixel_data/100e/'
        #'/home/serene/Documents/InVehicleCamera/save_kitti/'

        # save_directory += str(sample_args.test_dataset)+'/'
        # save_directory += 'save_attention'

        # Define the path for the config file for saved args
        with open(os.path.join(save_directory, 'config.pkl'), 'rb') as f:
            saved_args = pickle.load(f)

        # Initialize net
        net = SRNN(saved_args, True)
        net.cuda()

        ## TODO: visualize trained weights
        # plt.imshow(net.humanNodeRNN.edge_embed.weight)
        # plt.colorbar()
        # plt.show()
        checkpoint_path = os.path.join(save_directory, 'srnn_model_'+str(sample_args.epoch)+'.tar')

        if os.path.isfile(checkpoint_path):
            print('Loading checkpoint')
            checkpoint = torch.load(checkpoint_path)
            # model_iteration = checkpoint['iteration']
            model_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            print('Loaded checkpoint at {}'.format(model_epoch))

        H_mat = np.loadtxt(H_path[i])

        avg = []
        ade = []
        # Dataset to get data from
        dataset = [sample_args.test_dataset]
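        # The full test split is re-evaluated 30 times with stochastic
        # sampling; per-run ADE/FDE go into ade/avg and are averaged at the end.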
        for c in range(30):
            dataloader = DataLoader(1, sample_args.pred_length + sample_args.obs_length, dataset, True, infer=True)

            dataloader.reset_batch_pointer()

            # Construct the ST-graph object
            stgraph = ST_GRAPH(1, sample_args.pred_length + sample_args.obs_length)

            results = []

            # Variable to maintain total error
            total_error = 0
            final_error = 0
            minimum = 1000
            min_final = 1000
            for batch in range(dataloader.num_batches):
                start = time.time()

                # Get the next batch
                x, _, frameIDs, d = dataloader.next_batch(randomUpdate=False)

                # Construct ST graph
                stgraph.readGraph(x, ds_ptr=d,threshold=1.0)

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()
                #obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float(), volatile=True).cuda()
                edges = Variable(torch.from_numpy(edges).float(), volatile=True).cuda()

                # obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                # obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()
                # NOTE: obs_ : observed
                # Separate out the observed part of the trajectory
                obs_nodes, obs_edges, obs_nodesPresent, obs_edgesPresent = (
                    nodes[:sample_args.obs_length],
                    edges[:sample_args.obs_length],
                    nodesPresent[:sample_args.obs_length],
                    edgesPresent[:sample_args.obs_length],
                )

                # Separate out the observed obstacles in a given sequence
                # obsnodes_v, obsEdges_v , obsNodesPresent_v , obsEdgesPresent_v = obsNodes[:sample_args.obs_length], obsEdges[:sample_args.obs_length], obsNodesPresent[:sample_args.obs_length], obsEdgesPresent[:sample_args.obs_length]

                # Sample function
                ret_nodes, ret_attn = sample(obs_nodes, obs_edges, obs_nodesPresent,
                                             obs_edgesPresent, sample_args, net,
                                             nodes, edges, nodesPresent)
                    # , obsnodes_v , obsEdges_v, obsNodesPresent_v,
                    #   obsEdgesPresent_v ,  )

                # Compute mean and final displacement error
                total_error += get_mean_error(ret_nodes[sample_args.obs_length:].data, nodes[sample_args.obs_length:].data,
                                              nodesPresent[sample_args.obs_length - 1],
                                              nodesPresent[sample_args.obs_length:], H_mat, i)

                final_error += get_final_error(ret_nodes[sample_args.obs_length:].data, nodes[sample_args.obs_length:].data,
                                               nodesPresent[sample_args.obs_length - 1],
                                               nodesPresent[sample_args.obs_length:], H_mat, i)

                # if total_error < minimum:
                #     minimum = total_error
                # if final_error < min_final:
                #     min_final = final_error

                end = time.time()

                # Store results
                results.append((nodes.data.cpu().numpy(), ret_nodes.data.cpu().numpy(), nodesPresent, sample_args.obs_length, ret_attn, frameIDs))
                # zfill = 3

                # for i in range(len(results)):
                #     skip = str(int(results[i][5][0][8])).zfill(zfill)
                #     # img_file = '/home/serene/Documents/video/hotel/frame-{0}.jpg'.format(skip)
                #     # for j in range(20):
                #     #     if i == 40:
                #
                #     img_file = '/home/serene/Documents/copy_srnn_pytorch/data/ucy/zara/zara.png'
                #     name = plot_directory  + 'sequence_zara' + str(i)  # /pedestrian_1
                #     # for k in range(20):
                #     vis.plot_trajectories(results[i][0], results[i][1], results[i][2], results[i][3], name,
                #                       plot_directory, img_file, 1)
                #     if int(skip) >= 999 and zfill < 4:
                #         zfill = zfill + 1
                #     elif int(skip) >= 9999 and zfill < 5:
                #         zfill = zfill + 1

                # Reset the ST graph
                stgraph.reset()

            print('Total mean error of the model is ', total_error / dataloader.num_batches)
            print('Total final error of the model is ', final_error / dataloader.num_batches)
            ade.append(total_error / dataloader.num_batches)
            avg.append(final_error / dataloader.num_batches)
            print('Saving results')
            with open(os.path.join(save_directory, 'results.pkl'), 'wb') as f:
                pickle.dump(results, f)

        print('average FDE', np.average(avg))
        print('average ADE', np.average(ade))

        np.savetxt(os.path.join(save_directory, 'sampling_results.txt'),
                   (ade, avg), fmt='%.03e')