Example #1
def sample(nodes, nodesPresent, grid, args, net, true_nodes, true_nodesPresent, true_grid, saved_args, dimensions):
    '''
    The sample function
    params:
    nodes: Input positions
    nodesPresent: Peds present in each frame
    grid: Grid masks for the observed frames
    args: arguments
    net: The model
    true_nodes: True positions
    true_nodesPresent: The true peds present in each frame
    true_grid: The true grid masks
    saved_args: Training arguments
    dimensions: The dimensions of the dataset
    '''
    # Number of peds in the sequence
    numNodes = nodes.size()[1]

    # Construct variables for hidden and cell states
    hidden_states = Variable(torch.zeros(numNodes, net.args.rnn_size), volatile=True).cuda()
    cell_states = Variable(torch.zeros(numNodes, net.args.rnn_size), volatile=True).cuda()

    # For the observed part of the trajectory
    for tstep in range(args.obs_length-1):
        # Do a forward prop
        out_obs, hidden_states, cell_states = net(nodes[tstep].view(1, numNodes, 2), [grid[tstep]], [nodesPresent[tstep]], hidden_states, cell_states)
        # loss_obs = Gaussian2DLikelihood(out_obs, nodes[tstep+1].view(1, numNodes, 2), [nodesPresent[tstep+1]])

    # Initialize the return data structure
    ret_nodes = Variable(torch.zeros(args.obs_length+args.pred_length, numNodes, 2), volatile=True).cuda()
    ret_nodes[:args.obs_length, :, :] = nodes.clone()

    # Last seen grid
    prev_grid = grid[-1].clone()

    # For the predicted part of the trajectory
    for tstep in range(args.obs_length-1, args.pred_length + args.obs_length - 1):
        # Do a forward prop
        outputs, hidden_states, cell_states = net(ret_nodes[tstep].view(1, numNodes, 2), [prev_grid], [nodesPresent[args.obs_length-1]], hidden_states, cell_states)
        # loss_pred = Gaussian2DLikelihoodInference(outputs, true_nodes[tstep+1].view(1, numNodes, 2), nodesPresent[args.obs_length-1], [true_nodesPresent[tstep+1]])

        # Extract the mean, std and corr of the bivariate Gaussian
        mux, muy, sx, sy, corr = getCoef(outputs)
        # Sample from the bivariate Gaussian
        next_x, next_y = sample_gaussian_2d(mux.data, muy.data, sx.data, sy.data, corr.data, nodesPresent[args.obs_length-1])

        # Store the predicted position
        ret_nodes[tstep + 1, :, 0] = next_x
        ret_nodes[tstep + 1, :, 1] = next_y

        # List of nodes at the last time-step (assuming they exist until the end)
        list_of_nodes = Variable(torch.LongTensor(nodesPresent[args.obs_length-1]), volatile=True).cuda()
        # Get their predicted positions
        current_nodes = torch.index_select(ret_nodes[tstep+1], 0, list_of_nodes)

        # Compute the new grid masks with the predicted positions
        prev_grid = getGridMaskInference(current_nodes.data.cpu().numpy(), dimensions, saved_args.neighborhood_size, saved_args.grid_size)
        prev_grid = Variable(torch.from_numpy(prev_grid).float(), volatile=True).cuda()

    return ret_nodes
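
The example above depends on a sample_gaussian_2d helper that is not shown in the listing. Below is a minimal sketch of what such a helper typically looks like in Social-LSTM-style code; it is an assumption about its behavior, not the repository's exact helper (the version in Example #6, for instance, additionally takes a look_up table).

import numpy as np
import torch

def sample_gaussian_2d(mux, muy, sx, sy, corr, nodesPresent):
    # Hypothetical sketch: draw one (x, y) sample per present node from its
    # predicted bivariate Gaussian. All parameter tensors have shape 1 x numNodes.
    numNodes = mux.size()[1]
    next_x = torch.zeros(numNodes)
    next_y = torch.zeros(numNodes)
    for node in range(numNodes):
        if node not in nodesPresent:
            continue
        mean = [float(mux[0, node]), float(muy[0, node])]
        cov = [[float(sx[0, node]) ** 2,
                float(corr[0, node]) * float(sx[0, node]) * float(sy[0, node])],
               [float(corr[0, node]) * float(sx[0, node]) * float(sy[0, node]),
                float(sy[0, node]) ** 2]]
        # One draw from the 2D Gaussian for this node
        sampled = np.random.multivariate_normal(mean, cov, 1)
        next_x[node] = sampled[0][0]
        next_y[node] = sampled[0][1]
    return next_x, next_y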
Example #2
def Gaussian2DLikelihood(outputs, targets, nodesPresent, pred_length):
    '''
    Computes the likelihood of predicted locations under a bivariate Gaussian distribution
    params:
    outputs: Torch variable containing tensor of shape seq_length x numNodes x output_size (e.g. 20*10*5)
    targets: Torch variable containing tensor of shape seq_length x numNodes x input_size
    nodesPresent : A list of lists, of size seq_length. Each list contains the nodeIDs that are present in the frame
    pred_length : Length of the predicted part of the trajectory
    '''

    # Get the sequence length
    seq_length = outputs.size()[0]
    # Get the observed length
    obs_length = seq_length - pred_length

    # Extract mean, std devs and correlation
    mux, muy, sx, sy, corr = getCoef(outputs)
    # print('outputs={},size={}'.format(outputs,outputs.size))
    # print('input={},size={}'.format(targets, outputs.size))
    #print('mux, muy, sx, sy, corr:{}{}{}{}'.format(mux, muy, sx, sy, corr))

    # Compute factors
    normx = targets[:, :, 0] - mux
    normy = targets[:, :, 1] - muy
    sxsy = sx * sy
    z = torch.pow((normx / sx), 2) + torch.pow(
        (normy / sy), 2) - 2 * ((corr * normx * normy) / sxsy)
    negRho = 1 - torch.pow(corr, 2)

    # Numerator
    result = torch.exp(-z / (2 * negRho))
    # Normalization factor
    denom = 2 * np.pi * (sxsy * torch.sqrt(negRho))

    # Final PDF calculation
    result = result / denom

    # Numerical stability
    epsilon = 1e-20
    result = -torch.log(torch.clamp(result, min=epsilon))
    #print('result={}'.format(result))
    # Compute the loss across all frames and all nodes
    loss = 0
    counter = 0

    for framenum in range(obs_length, seq_length):
        nodeIDs = nodesPresent[framenum]

        for nodeID in nodeIDs:

            loss = loss + result[framenum, nodeID]
            counter = counter + 1

    if counter != 0:
        return loss / counter
    else:
        return loss
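
getCoef is also not part of the listing. In these repositories it usually converts the raw network output into valid Gaussian parameters by exponentiating the standard deviations and squashing the correlation with tanh; a minimal sketch under that assumption:

import torch

def getCoef(outputs):
    # Hypothetical sketch: split the last dimension (size 5) into the
    # parameters of a bivariate Gaussian and map them to valid ranges.
    mux, muy, sx, sy, corr = (outputs[:, :, 0], outputs[:, :, 1],
                              outputs[:, :, 2], outputs[:, :, 3],
                              outputs[:, :, 4])
    sx = torch.exp(sx)       # standard deviations must be positive
    sy = torch.exp(sy)
    corr = torch.tanh(corr)  # correlation must lie in (-1, 1)
    return mux, muy, sx, sy, corr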
Example #3
def Gaussian2DLikelihoodInference(outputs, targets, assumedNodesPresent,
                                  nodesPresent, use_cuda):
    '''
    Computes the likelihood of predicted locations under a bivariate Gaussian distribution at test time
    params:
    outputs : predicted locations
    targets : true locations
    assumedNodesPresent : Nodes assumed to be present in each frame in the sequence
    nodesPresent : True nodes present in each frame in the sequence
    use_cuda : Whether to compute the loss on the GPU
    '''
    # Extract mean, std devs and correlation
    mux, muy, sx, sy, corr = getCoef(outputs)

    # Compute factors
    normx = targets[:, :, 0] - mux
    normy = targets[:, :, 1] - muy
    sxsy = sx * sy
    z = (normx / sx)**2 + (normy / sy)**2 - 2 * ((corr * normx * normy) / sxsy)
    negRho = 1 - corr**2

    # Numerator
    result = torch.exp(-z / (2 * negRho))
    # Normalization factor
    denom = 2 * np.pi * (sxsy * torch.sqrt(negRho))

    # Final PDF calculation
    result = result / denom

    # Numerical stability
    epsilon = 1e-20

    result = -torch.log(torch.clamp(result, min=epsilon))

    # Compute the loss
    loss = Variable(torch.zeros(1))
    if use_cuda:
        loss = loss.cuda()
    counter = 0

    for framenum in range(outputs.size()[0]):
        nodeIDs = nodesPresent[framenum]

        for nodeID in nodeIDs:
            if nodeID not in assumedNodesPresent:
                # If the node wasn't assumed to be present, don't compute loss for it
                continue
            loss = loss + result[framenum, nodeID]
            counter = counter + 1

    if counter != 0:
        return loss / counter
    else:
        return loss
Example #4
def Gaussian2DLikelihood(outputs, targets, nodesPresent, pred_length):
    '''
    Parameters:

    outputs: Torch variable containing tensor of shape seq_length x numNodes x 1 x output_size
    targets: Torch variable containing tensor of shape seq_length x numNodes x 1 x input_size
    nodesPresent : A list of lists, of size seq_length. Each list contains the nodeIDs that are present in the frame
    pred_length : Length of the predicted part of the trajectory
    '''
    seq_length = outputs.size()[0]
    obs_length = seq_length - pred_length

    # Extract mean, std devs and correlation
    mux, muy, sx, sy, corr = getCoef(outputs)

    # Compute factors
    normx = targets[:, :, 0] - mux
    normy = targets[:, :, 1] - muy
    sxsy = sx * sy

    z = (normx / sx)**2 + (normy / sy)**2 - 2 * ((corr * normx * normy) / sxsy)
    negRho = 1 - corr**2

    # Numerator
    result = torch.exp(-z / (2 * negRho))
    # Normalization factor
    denom = 2 * np.pi * (sxsy * torch.sqrt(negRho))

    # Final PDF calculation
    result = result / denom

    # Numerical stability
    epsilon = 1e-20

    result = -torch.log(torch.clamp(result, min=epsilon))

    loss = 0
    counter = 0

    for framenum in range(obs_length, seq_length):
        nodeIDs = nodesPresent[framenum]

        for nodeID in nodeIDs:

            loss = loss + result[framenum, nodeID]
            counter = counter + 1

    if counter != 0:
        return loss / counter
    else:
        return loss
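
For reference, Examples #2, #3 and #4 all evaluate the same quantity: the negative log-likelihood of each target position (x, y) under the predicted bivariate Gaussian. Written out, the code above computes

\[
z = \left(\frac{x-\mu_x}{\sigma_x}\right)^2 + \left(\frac{y-\mu_y}{\sigma_y}\right)^2 - \frac{2\rho\,(x-\mu_x)(y-\mu_y)}{\sigma_x\sigma_y},
\qquad
\mathcal{N}(x,y) = \frac{\exp\!\left(-\frac{z}{2(1-\rho^2)}\right)}{2\pi\,\sigma_x\sigma_y\sqrt{1-\rho^2}},
\]
\[
\text{loss} = -\frac{1}{|P|}\sum_{(t,i)\in P}\log\bigl(\max(\mathcal{N}(x_{t,i}, y_{t,i}),\ \epsilon)\bigr),
\]

where P is the set of (frame, node) pairs present in the evaluated frames and \epsilon = 10^{-20} is the clamping constant used for numerical stability.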
Example #5
def sample(
    nodes,
    edges,
    nodesPresent,
    edgesPresent,
    args,
    net,
    true_nodes,
    true_edges,
    true_nodesPresent,
):
    """
    Sample function
    Parameters
    ==========

    nodes : A tensor of shape obs_length x numNodes x 2
    Each row contains (x, y)

    edges : A tensor of shape obs_length x numNodes x numNodes x 2
    Each row contains the vector representing the edge
    If edge doesn't exist, then the row contains zeros

    nodesPresent : A list of lists, of size obs_length
    Each list contains the nodeIDs that are present in the frame

    edgesPresent : A list of lists, of size obs_length
    Each list contains tuples of nodeIDs that have edges in the frame

    args : Sampling Arguments

    net : The network

    Returns
    =======

    ret_nodes : A tensor of shape (obs_length + pred_length) x numNodes x 2
    Contains the true and predicted positions of all the nodes
    """
    # Number of nodes
    numNodes = nodes.size()[1]

    # Initialize hidden states for the nodes
    h_nodes = Variable(torch.zeros(numNodes, net.args.node_rnn_size),
                       volatile=True)
    h_edges = Variable(torch.zeros(numNodes * numNodes,
                                   net.args.edge_rnn_size),
                       volatile=True)
    c_nodes = Variable(torch.zeros(numNodes, net.args.node_rnn_size),
                       volatile=True)
    c_edges = Variable(torch.zeros(numNodes * numNodes,
                                   net.args.edge_rnn_size),
                       volatile=True)
    h_super_node = Variable(torch.zeros(3, net.args.node_rnn_size),
                            volatile=True)
    c_super_node = Variable(torch.zeros(3, net.args.node_rnn_size),
                            volatile=True)
    h_super_edges = Variable(torch.zeros(3, net.args.edge_rnn_size),
                             volatile=True)
    c_super_edges = Variable(torch.zeros(3, net.args.edge_rnn_size),
                             volatile=True)
    if args.use_cuda:
        h_nodes = h_nodes.cuda()
        h_edges = h_edges.cuda()
        c_nodes = c_nodes.cuda()
        c_edges = c_edges.cuda()
        h_super_node = h_super_node.cuda()
        c_super_node = c_super_node.cuda()
        h_super_edges = h_super_edges.cuda()
        c_super_edges = c_super_edges.cuda()

    # Propagate the observed length of the trajectory
    for tstep in range(args.obs_length - 1):
        # Forward prop
        out_obs, h_nodes, h_edges, c_nodes, c_edges, h_super_node, h_super_edges, c_super_node, c_super_edges, _ = net(
            nodes[tstep].view(1, numNodes, 2),
            edges[tstep].view(1, numNodes * numNodes, 2),
            [nodesPresent[tstep]],
            [edgesPresent[tstep]],
            h_nodes,
            h_edges,
            c_nodes,
            c_edges,
            h_super_node,
            h_super_edges,
            c_super_node,
            c_super_edges,
        )
        # loss_obs = Gaussian2DLikelihood(out_obs, nodes[tstep+1].view(1, numNodes, 2), [nodesPresent[tstep+1]])

    # Initialize the return data structures
    ret_nodes = Variable(torch.zeros(args.obs_length + args.pred_length,
                                     numNodes, 2),
                         volatile=True)
    if args.use_cuda:
        ret_nodes = ret_nodes.cuda()
    ret_nodes[:args.obs_length, :, :] = nodes.clone()

    ret_edges = Variable(
        torch.zeros((args.obs_length + args.pred_length), numNodes * numNodes,
                    2),
        volatile=True,
    )
    if args.use_cuda:
        ret_edges = ret_edges.cuda()
    ret_edges[:args.obs_length, :, :] = edges.clone()

    ret_attn = []

    # Propagate the predicted length of trajectory (sampling from previous prediction)
    for tstep in range(args.obs_length - 1,
                       args.pred_length + args.obs_length - 1):
        # TODO Not keeping track of nodes leaving the frame (or new nodes entering the frame, which I don't think we can do anyway)
        # Forward prop
        outputs, h_nodes, h_edges, c_nodes, c_edges, h_super_node, h_super_edges, c_super_node, c_super_edges, attn_w = net(
            ret_nodes[tstep].view(1, numNodes, 2),
            ret_edges[tstep].view(1, numNodes * numNodes, 2),
            [nodesPresent[args.obs_length - 1]],
            [edgesPresent[args.obs_length - 1]],
            h_nodes,
            h_edges,
            c_nodes,
            c_edges,
            h_super_node,
            h_super_edges,
            c_super_node,
            c_super_edges,
        )
        mux, muy, sx, sy, corr = getCoef(outputs)
        next_x, next_y = sample_gaussian_2d(
            mux.data,
            muy.data,
            sx.data,
            sy.data,
            corr.data,
            nodesPresent[args.obs_length - 1],
        )

        ret_nodes[tstep + 1, :, 0] = next_x
        ret_nodes[tstep + 1, :, 1] = next_y

        # Compute edges
        # TODO Currently, assuming edges from the last observed time-step will stay for the entire prediction length
        ret_edges[tstep + 1, :, :] = compute_edges(
            ret_nodes.data, tstep + 1, edgesPresent[args.obs_length - 1],
            args.use_cuda)
        # Store computed attention weights
        ret_attn.append(attn_w[0])

    return ret_nodes, ret_attn
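
Examples #5 and #7 rely on a compute_edges helper to refresh the edge features from the newly predicted positions. The sketch below shows what such a helper typically computes in structural-RNN-style code and is an assumption rather than the repository's exact implementation: spatial edges hold the relative position of two different nodes at time-step tstep, and temporal (self) edges hold a node's displacement between tstep-1 and tstep.

import torch

def compute_edges(nodes, tstep, edgesPresent, use_cuda=False):
    # Hypothetical sketch. nodes: (obs_length + pred_length) x numNodes x 2 tensor,
    # edgesPresent: list of (nodeID_a, nodeID_b) tuples.
    numNodes = nodes.size()[1]
    edges = torch.zeros(numNodes * numNodes, 2)
    if use_cuda:
        edges = edges.cuda()
    for nodeID_a, nodeID_b in edgesPresent:
        if nodeID_a == nodeID_b:
            # Temporal (self) edge: displacement between consecutive time-steps
            edges[nodeID_a * numNodes + nodeID_b, :] = (
                nodes[tstep - 1, nodeID_a, :] - nodes[tstep, nodeID_b, :])
        else:
            # Spatial edge: relative position of the two nodes at tstep
            edges[nodeID_a * numNodes + nodeID_b, :] = (
                nodes[tstep, nodeID_a, :] - nodes[tstep, nodeID_b, :])
    return edges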
Example #6
def sample(x_seq,
           Pedlist,
           args,
           net,
           true_x_seq,
           true_Pedlist,
           saved_args,
           dimensions,
           dataloader,
           look_up,
           num_pedlist,
           is_gru,
           grid=None):
    '''
    The sample function
    params:
    x_seq: Input positions
    Pedlist: Peds present in each frame
    args: arguments
    net: The model
    true_x_seq: True positions
    true_Pedlist: The true peds present in each frame
    saved_args: Training arguments
    dimensions: The dimensions of the dataset
    dataloader: The dataloader utility object
    look_up: Look-up table mapping ped IDs to tensor indices
    num_pedlist: Number of peds in each frame
    is_gru: True if the model is a GRU (no cell state)
    grid: Grid masks for the observed frames (None for the vanilla LSTM)
    '''
    # Number of peds in the sequence
    numx_seq = len(look_up)

    with torch.no_grad():
        # Construct variables for hidden and cell states
        hidden_states = Variable(torch.zeros(numx_seq, net.args.rnn_size))
        if args.use_cuda:
            hidden_states = hidden_states.cuda()
        if not is_gru:
            cell_states = Variable(torch.zeros(numx_seq, net.args.rnn_size))
            if args.use_cuda:
                cell_states = cell_states.cuda()
        else:
            cell_states = None

        ret_x_seq = Variable(
            torch.zeros(args.obs_length + args.pred_length, numx_seq, 2))

        # Initialize the return data structure
        if args.use_cuda:
            ret_x_seq = ret_x_seq.cuda()

        # For the observed part of the trajectory
        for tstep in range(args.obs_length - 1):
            if grid is None:  #vanilla lstm
                # Do a forward prop
                out_obs, hidden_states, cell_states = net(
                    x_seq[tstep].view(1, numx_seq, 2), hidden_states,
                    cell_states, [Pedlist[tstep]], [num_pedlist[tstep]],
                    dataloader, look_up)
            else:
                # Do a forward prop
                out_obs, hidden_states, cell_states = net(
                    x_seq[tstep].view(1, numx_seq, 2), [grid[tstep]],
                    hidden_states, cell_states, [Pedlist[tstep]],
                    [num_pedlist[tstep]], dataloader, look_up)
            # loss_obs = Gaussian2DLikelihood(out_obs, x_seq[tstep+1].view(1, numx_seq, 2), [Pedlist[tstep+1]])

            # Extract the mean, std and corr of the bivariate Gaussian
            mux, muy, sx, sy, corr = getCoef(out_obs)
            # Sample from the bivariate Gaussian
            next_x, next_y = sample_gaussian_2d(mux.data, muy.data, sx.data,
                                                sy.data, corr.data,
                                                true_Pedlist[tstep], look_up)
            ret_x_seq[tstep + 1, :, 0] = next_x
            ret_x_seq[tstep + 1, :, 1] = next_y

        ret_x_seq[:args.obs_length, :, :] = x_seq.clone()

        # Last seen grid
        if grid is not None:  # not the vanilla LSTM
            prev_grid = grid[-1].clone()

        #assign last position of observed data to temp
        #temp_last_observed = ret_x_seq[args.obs_length-1].clone()
        #ret_x_seq[args.obs_length-1] = x_seq[args.obs_length-1]

        # For the predicted part of the trajectory
        for tstep in range(args.obs_length - 1,
                           args.pred_length + args.obs_length - 1):
            # Do a forward prop
            if grid is None:  #vanilla lstm
                outputs, hidden_states, cell_states = net(
                    ret_x_seq[tstep].view(1, numx_seq, 2), hidden_states,
                    cell_states, [true_Pedlist[tstep]], [num_pedlist[tstep]],
                    dataloader, look_up)
            else:
                outputs, hidden_states, cell_states = net(
                    ret_x_seq[tstep].view(1, numx_seq, 2), [prev_grid],
                    hidden_states, cell_states, [true_Pedlist[tstep]],
                    [num_pedlist[tstep]], dataloader, look_up)

            # Extract the mean, std and corr of the bivariate Gaussian
            mux, muy, sx, sy, corr = getCoef(outputs)
            # Sample from the bivariate Gaussian
            next_x, next_y = sample_gaussian_2d(mux.data, muy.data, sx.data,
                                                sy.data, corr.data,
                                                true_Pedlist[tstep], look_up)

            # Store the predicted position
            ret_x_seq[tstep + 1, :, 0] = next_x
            ret_x_seq[tstep + 1, :, 1] = next_y

            # List of x_seq at the last time-step (assuming they exist until the end)
            true_Pedlist[tstep + 1] = [
                int(_x_seq) for _x_seq in true_Pedlist[tstep + 1]
            ]
            next_ped_list = copy.deepcopy(true_Pedlist[tstep + 1])
            converted_pedlist = [look_up[_x_seq] for _x_seq in next_ped_list]
            list_of_x_seq = Variable(torch.LongTensor(converted_pedlist))

            if args.use_cuda:
                list_of_x_seq = list_of_x_seq.cuda()

            #Get their predicted positions
            current_x_seq = torch.index_select(ret_x_seq[tstep + 1], 0,
                                               list_of_x_seq)

            if grid is not None:  # not the vanilla LSTM
                # Compute the new grid masks with the predicted positions
                if args.method == 2:  #obstacle lstm
                    prev_grid = getGridMask(current_x_seq.data.cpu(),
                                            dimensions,
                                            len(true_Pedlist[tstep + 1]),
                                            saved_args.neighborhood_size,
                                            saved_args.grid_size, True)
                elif args.method == 1:  #social lstm
                    prev_grid = getGridMask(current_x_seq.data.cpu(),
                                            dimensions,
                                            len(true_Pedlist[tstep + 1]),
                                            saved_args.neighborhood_size,
                                            saved_args.grid_size)

                prev_grid = Variable(torch.from_numpy(prev_grid).float())
                if args.use_cuda:
                    prev_grid = prev_grid.cuda()

        #ret_x_seq[args.obs_length-1] = temp_last_observed

        return ret_x_seq
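
Example #6 rebuilds the social grid with getGridMask after every predicted step. A condensed sketch of what such a grid-mask helper typically computes is given below; the details (coordinate normalization, occupancy variant) are assumptions and may differ from the repository's actual helper. For each pedestrian, a neighborhood box centred on them is divided into grid_size x grid_size cells, and the mask records which cell every other pedestrian falls into.

import itertools
import numpy as np
import torch

def getGridMask(frame, dimensions, num_person, neighborhood_size, grid_size, is_occupancy=False):
    # Hypothetical sketch. frame: num_person x 2 positions (CPU tensor or array),
    # dimensions: (width, height) of the scene.
    width, height = dimensions[0], dimensions[1]
    frame_np = frame.numpy() if torch.is_tensor(frame) else np.asarray(frame)
    if is_occupancy:
        # Occupancy variant (obstacle LSTM): one grid per pedestrian
        frame_mask = np.zeros((num_person, grid_size ** 2))
    else:
        # Social variant: one grid per (pedestrian, other pedestrian) pair
        frame_mask = np.zeros((num_person, num_person, grid_size ** 2))
    width_bound = neighborhood_size / float(width)
    height_bound = neighborhood_size / float(height)
    for i, j in itertools.permutations(range(num_person), 2):
        current_x, current_y = frame_np[i, 0], frame_np[i, 1]
        other_x, other_y = frame_np[j, 0], frame_np[j, 1]
        width_low, width_high = current_x - width_bound / 2, current_x + width_bound / 2
        height_low, height_high = current_y - height_bound / 2, current_y + height_bound / 2
        if not (width_low <= other_x < width_high and height_low <= other_y < height_high):
            continue  # pedestrian j is outside pedestrian i's neighborhood
        cell_x = int(np.floor(((other_x - width_low) / width_bound) * grid_size))
        cell_y = int(np.floor(((other_y - height_low) / height_bound) * grid_size))
        if is_occupancy:
            frame_mask[i, cell_x * grid_size + cell_y] = 1
        else:
            frame_mask[i, j, cell_x * grid_size + cell_y] = 1
    return frame_mask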
Example #7
def sample(
    nodes,
    edges,
    nodesPresent,
    edgesPresent,
    args,
    net,
    true_nodes,
    true_edges,
    true_nodesPresent,
):
    """
    Sample function
    Parameters
    ==========

    nodes : A tensor of shape obs_length x numNodes x 2
    Each row contains (x, y)

    edges : A tensor of shape obs_length x numNodes x numNodes x 2
    Each row contains the vector representing the edge
    If edge doesn't exist, then the row contains zeros

    nodesPresent : A list of lists, of size obs_length
    Each list contains the nodeIDs that are present in the frame

    edgesPresent : A list of lists, of size obs_length
    Each list contains tuples of nodeIDs that have edges in the frame

    args : Sampling Arguments

    net : The network

    Returns
    =======

    ret_nodes : A tensor of shape (obs_length + pred_length) x numNodes x 2
    Contains the true and predicted positions of all the nodes
    """
    # Number of nodes
    numNodes = nodes.size()[1]

    # Initialize hidden states for the nodes
    h_nodes = Variable(torch.zeros(numNodes, net.args.node_rnn_size),
                       volatile=True)
    h_edges = Variable(torch.zeros(numNodes * numNodes,
                                   net.args.edge_rnn_size),
                       volatile=True)
    c_nodes = Variable(torch.zeros(numNodes, net.args.node_rnn_size),
                       volatile=True)
    c_edges = Variable(torch.zeros(numNodes * numNodes,
                                   net.args.edge_rnn_size),
                       volatile=True)
    h_super_node = Variable(torch.zeros(3, net.args.node_rnn_size),
                            volatile=True)
    c_super_node = Variable(torch.zeros(3, net.args.node_rnn_size),
                            volatile=True)
    h_super_edges = Variable(torch.zeros(3, net.args.edge_rnn_size),
                             volatile=True)
    c_super_edges = Variable(torch.zeros(3, net.args.edge_rnn_size),
                             volatile=True)
    if args.use_cuda:
        h_nodes = h_nodes.cuda()
        h_edges = h_edges.cuda()
        c_nodes = c_nodes.cuda()
        c_edges = c_edges.cuda()
        h_super_node = h_super_node.cuda()
        c_super_node = c_super_node.cuda()
        h_super_edges = h_super_edges.cuda()
        c_super_edges = c_super_edges.cuda()

    # NOTE: Propagate over the observed part of the trajectory; this step builds up the hidden and cell states
    for tstep in range(args.obs_length - 1):  # up to the observed length
        # Forward prop
        out_obs, h_nodes, h_edges, c_nodes, c_edges, h_super_node, h_super_edges, c_super_node, c_super_edges, _ = net(
            nodes[tstep].view(1, numNodes, 2),
            edges[tstep].view(1, numNodes * numNodes, 2),
            [nodesPresent[tstep]],
            [edgesPresent[tstep]],
            h_nodes,
            h_edges,
            c_nodes,
            c_edges,
            h_super_node,
            h_super_edges,
            c_super_node,
            c_super_edges,
        )
        # loss_obs = Gaussian2DLikelihood(out_obs, nodes[tstep+1].view(1, numNodes, 2), [nodesPresent[tstep+1]])

    # Initialize the return data structures
    ret_nodes = Variable(torch.zeros(args.obs_length + args.pred_length,
                                     numNodes, 2),
                         volatile=True)
    if args.use_cuda:
        ret_nodes = ret_nodes.cuda()
    ret_nodes[:args.obs_length, :, :] = nodes.clone()  #### Ground Truth

    ret_edges = Variable(
        torch.zeros((args.obs_length + args.pred_length), numNodes * numNodes,
                    2),
        volatile=True,
    )
    if args.use_cuda:
        ret_edges = ret_edges.cuda()
    ret_edges[:args.obs_length, :, :] = edges.clone()  #### Ground Truth

    ret_attn = []

    # Propagate the predicted length of trajectory (sampling from previous prediction)
    # -- this is the prediction step
    for tstep in range(args.obs_length - 1,
                       args.pred_length + args.obs_length - 1):
        # TODO Not keeping track of nodes leaving the frame (or new nodes entering the frame, which I don't think we can do anyway)
        # Forward prop (recursive): the model maps the current state at time t to the next state with an LSTM cell,
        # so once the observed sequence has been fed in, the cell predicts output_{t+1}, and that output is fed back
        # into the cell to produce output_{t+2}, and so on.
        outputs, h_nodes, h_edges, c_nodes, c_edges, h_super_node, h_super_edges, c_super_node, c_super_edges, attn_w = net(
            ret_nodes[tstep].view(1, numNodes, 2),
            ret_edges[tstep].view(1, numNodes * numNodes, 2),
            [nodesPresent[args.obs_length - 1]],
            [edgesPresent[args.obs_length - 1]],
            h_nodes,
            h_edges,
            c_nodes,
            c_edges,
            h_super_node,
            h_super_edges,
            c_super_node,
            c_super_edges,
        )  #### the output holds (mux, muy, std, corr) for every node present in the current frame

        mux, muy, sx, sy, corr = getCoef(outputs)
        #TODO
        next_x, next_y = sample_gaussian_2d(  # build the 2D Gaussian and use one sampled point as (x, y) -- is this right...?
            mux.data,
            muy.data,
            sx.data,
            sy.data,
            corr.data,
            nodesPresent[args.obs_length -
                         1],  # only the trajectories of objects present in the current scene are predicted
        )

        ret_nodes[tstep + 1, :, 0] = next_x
        ret_nodes[tstep + 1, :, 1] = next_y

        # Compute edges
        # TODO Currently, assuming edges from the last observed time-step will stay for the entire prediction length
        ret_edges[tstep + 1, :, :] = compute_edges(
            ret_nodes.data, tstep + 1, edgesPresent[args.obs_length - 1],
            args.use_cuda)
        # Store computed attention weights
        ret_attn.append(attn_w[0])

    return ret_nodes, ret_attn
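
Once ret_nodes is returned, it is typically scored against the ground-truth positions with an average displacement error over the prediction window. A minimal sketch of such an evaluation (a hypothetical helper, not part of the examples above):

import torch

def get_mean_error(ret_nodes, true_nodes, assumedNodesPresent, obs_length):
    # Hypothetical sketch: mean Euclidean distance between predicted and true
    # positions over the prediction window, restricted to the nodes that were
    # assumed to be present (those visible in the last observed frame).
    seq_length = ret_nodes.size()[0]
    error = 0.0
    counter = 0
    for tstep in range(obs_length, seq_length):
        for nodeID in assumedNodesPresent:
            diff = ret_nodes[tstep, nodeID, :] - true_nodes[tstep, nodeID, :]
            error += float(torch.norm(diff, 2))
            counter += 1
    return error / counter if counter != 0 else error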