Example #1
def train(args):
    datasets = list(range(5))
    # Remove the leave out dataset from the datasets
    datasets.remove(args.leaveDataset)
    # datasets = [0]
    # args.leaveDataset = 0

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            datasets,
                            forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = 'log_w_goal/'
    log_directory += str(args.leaveDataset) + '/'
    log_directory += 'log_attention'
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging file
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    save_directory = 'save_w_goal/'
    save_directory += str(args.leaveDataset) + '/'
    save_directory += 'save_attention'
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Open the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_' + str(x) + '.tar')

    # Initialize net
    net = SRNN(args)
    net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)

    learning_rate = args.learning_rate
    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.
                  format(epoch * dataloader.num_batches + batch,
                         args.num_epochs * dataloader.num_batches, epoch,
                         loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)

                loss_batch += loss.data[0]

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')

        # Save the model after each epoch
        print('Saving model')
        torch.save(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
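Example #1 (like most of the examples below) calls Gaussian2DLikelihood without defining it. The following is a minimal sketch of such a loss, not the repository's actual implementation: it assumes outputs[t, n] holds five raw bivariate-Gaussian parameters (mu_x, mu_y, log sigma_x, log sigma_y, pre-tanh rho) and that nodesPresent[t] lists the node indices present in frame t; the snake_case name is likewise an assumption.

import math

import torch


def gaussian_2d_likelihood(outputs, targets, nodesPresent, pred_length):
    # Mean negative log-likelihood over the last pred_length frames,
    # restricted to the nodes actually present in each frame.
    seq_length = outputs.size(0)
    total, count = outputs.new_zeros(()), 0
    for t in range(seq_length - pred_length, seq_length):
        idx = torch.as_tensor(nodesPresent[t], dtype=torch.long,
                              device=outputs.device)
        if idx.numel() == 0:
            continue
        mux, muy = outputs[t, idx, 0], outputs[t, idx, 1]
        # exp/tanh keep the scales positive and the correlation in (-1, 1)
        sx, sy = torch.exp(outputs[t, idx, 2]), torch.exp(outputs[t, idx, 3])
        rho = torch.tanh(outputs[t, idx, 4])
        dx = (targets[t, idx, 0] - mux) / sx
        dy = (targets[t, idx, 1] - muy) / sy
        z = dx ** 2 + dy ** 2 - 2 * rho * dx * dy
        one_minus_rho2 = 1 - rho ** 2
        log_pdf = (-z / (2 * one_minus_rho2)
                   - torch.log(2 * math.pi * sx * sy
                               * torch.sqrt(one_minus_rho2)))
        total = total - log_pdf.sum()
        count += idx.numel()
    return total / max(count, 1)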
Example #2
def train(args):
    datasets = list(range(5))
    # Remove the leave out dataset from the datasets
    datasets.remove(args.leaveDataset)
    # datasets = [0]
    # args.leaveDataset = 0
    #
    # Construct the ST-graph object
    filedir = '../beijing/'
    grp = My_Graph(args.seq_length + 1, filedir + 'W.pk',
                   filedir + 'traffic_data.csv')

    my_datapt = My_DataPointer(grp.getFrameNum(), args.seq_length + 1,
                               args.batch_size)

    # Log directory
    log_directory = '../log-beijing-16/'
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)
    # Logging file
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    save_directory = '../save-beijing-16/'
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
    # Open the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_' + str(x) + '.tar')

    # Initialize net
    net = SRNN(args)
    net.cuda()

    if args.load >= 0:
        load_net(save_directory, args.load, net)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)

    learning_rate = args.learning_rate
    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    print(f"train_batches={my_datapt.num_batches()}")
    print(f"val_batches={my_datapt.valid_num_batches()}")

    # Training
    for epoch in range(args.num_epochs):
        log_output = open(log_directory + 'log_output.py', "w")
        log_output.write("pairs=[\n")

        loss_epoch = 0
        my_datapt.train_reset()

        # For each batch
        for batch in range(my_datapt.num_batches()):
            start = time.time()
            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for st in my_datapt.get_batch():
                # Construct the graph for the current sequence
                t1 = time.time()
                nodes, edges, nodesPresent, edgesPresent = grp.getSequence(st)

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes,
                                args.human_human_edge_rnn_size)).cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()
                # print(f"grp time = {time.time()-t1}")
                # Forward prop
                outputs, _, _, _, _, _ = net(
                    nodes[:args.seq_length], edges[:args.seq_length],
                    nodesPresent[:-1], edgesPresent[:-1],
                    hidden_states_node_RNNs, hidden_states_edge_RNNs,
                    cell_states_node_RNNs, cell_states_edge_RNNs)
                log_output.write(
                    f"[{outputs[-1,:,0].data.cpu().numpy().tolist()},{nodes[-1,:,0].data.cpu().numpy().tolist()}],\n"
                )
                # print(f"forward time = {time.time()-t1}")
                # Compute loss
                loss = getL2Loss(outputs, nodes[1:], nodesPresent[1:],
                                 args.pred_length)
                print(f"start={st},loss={loss}")
                # print(f"loss time = {time.time()-t1}")
                loss_batch += loss.item()

                # Compute gradients
                loss.backward()
                # print(f"backward time = {time.time()-t1}")
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()
                # print(f"step time = {time.time()-t1}")

            end = time.time()
            loss_batch = loss_batch / args.batch_size
            loss_epoch += loss_batch

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.
                  format(epoch * my_datapt.num_batches() + batch,
                         args.num_epochs * my_datapt.num_batches(), epoch,
                         loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= my_datapt.num_batches()
        # Log it
        log_file_curve.write(str(epoch) + ',' + str(loss_epoch) + ',')

        log_output.write("]\n")
        log_output.close()

        # Save the model after each epoch
        print('Saving model')
        torch.save(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

        # Validation
        my_datapt.val_reset()
        loss_epoch = 0

        with torch.no_grad():
            for batch in range(my_datapt.valid_num_batches()):

                # Loss for this batch
                loss_batch = 0

                for start in my_datapt.get_batch_val():

                    nodes, edges, nodesPresent, edgesPresent = grp.getSequence(
                        start)

                    # Convert to cuda variables
                    nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                    edges = Variable(torch.from_numpy(edges).float()).cuda()

                    # Define hidden states
                    numNodes = nodes.size()[1]
                    hidden_states_node_RNNs = Variable(
                        torch.zeros(numNodes,
                                    args.human_node_rnn_size)).cuda()
                    hidden_states_edge_RNNs = Variable(
                        torch.zeros(numNodes * numNodes,
                                    args.human_human_edge_rnn_size)).cuda()
                    cell_states_node_RNNs = Variable(
                        torch.zeros(numNodes,
                                    args.human_node_rnn_size)).cuda()
                    cell_states_edge_RNNs = Variable(
                        torch.zeros(numNodes * numNodes,
                                    args.human_human_edge_rnn_size)).cuda()

                    outputs, _, _, _, _, _ = net(
                        nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs)

                    # Compute loss
                    loss = getL2Loss(outputs, nodes[1:], nodesPresent[1:],
                                     args.pred_length)

                    loss_batch += loss.item()

                loss_batch = loss_batch / args.batch_size
                loss_epoch += loss_batch

            loss_epoch = loss_epoch / my_datapt.valid_num_batches()

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
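Example #2 swaps the Gaussian likelihood for getL2Loss, which is also not shown; the logging line that pairs outputs[-1, :, 0] with nodes[-1, :, 0] suggests a squared-error loss on the first channel. A minimal sketch under those interface assumptions (name, shapes, and channel choice are all assumptions):

import torch


def get_l2_loss(outputs, targets, nodesPresent, pred_length):
    # Mean squared error on channel 0 over the last pred_length frames,
    # counting only the nodes present in each frame.
    seq_length = outputs.size(0)
    total, count = outputs.new_zeros(()), 0
    for t in range(seq_length - pred_length, seq_length):
        idx = torch.as_tensor(nodesPresent[t], dtype=torch.long,
                              device=outputs.device)
        if idx.numel() == 0:
            continue
        diff = outputs[t, idx, 0] - targets[t, idx, 0]
        total = total + (diff ** 2).sum()
        count += idx.numel()
    return total / max(count, 1)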
Example #3
def train(args):
    ## 19th Feb: moved training files for 3 and 4 to fine_obstacle under copy_2

    # os.chdir('/home/serene/PycharmProjects/srnn-pytorch-master/')
    # context_factor is an experiment for including potential destinations of pedestrians in the graph.
    # did not yield good improvement
    os.chdir('/home/serene/Documents/copy_srnn_pytorch/srnn-pytorch-master')
    # os.chdir('/home/siri0005/srnn-pytorch-master')
    # base_dir = '../fine_obstacle/prelu/p_02/'

    base_dir = '../fine_obstacle/prelu/p_02/'  # alternative: '../MultiNodeAttn_HH/'
    # os.chdir('/home/serene/Documents/KITTIData/GT')

    datasets = [0, 1, 2, 3, 4]
    # Remove the leave out dataset from the datasets
    datasets.remove(args.leaveDataset)
    # datasets = [0]
    # args.leaveDataset = 0

    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, datasets, forcePreProcess=True)

    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    # log_directory = './log/world_data/normalized_01/'
    log_directory = base_dir+'log/'
    log_directory += str(args.leaveDataset) + '/'
    log_directory += 'log_attention/'

    # Logging file
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory
    # save_directory = './save/world_data/normalized_01/'
    save_directory = base_dir+'save/'
    save_directory += str(args.leaveDataset)+'/'
    save_directory += 'save_attention/'

    # Log the PReLU parameter
    param_log_dir = save_directory + 'param_log.txt'
    param_log = open(param_log_dir, 'w')

    # Open the configuration file
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_'+str(x)+'.tar')

    # Initialize net
    net = SRNN(args)
    # srnn_model = SRNN(args)

    # net = torch.nn.DataParallel(srnn_model)

    # CUDA_VISIBLE_DEVICES = 1
    net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)

    # learning_rate = args.learning_rate
    print('Training begin')
    best_val_loss = 100
    best_epoch = 0
    # start_epoch = 0

    # Optionally resume from a checkpoint:
    # if args.leaveDataset == 3:
    #     start_epoch = checkpoint_path(0)
    # ckp = torch.load(checkpoint_path(3))
    # net.load_state_dict(ckp['state_dict'])
    # optimizer.load_state_dict(ckp['optimizer_state_dict'])
    # ckp = torch.load(checkpoint_path(4))
    # net.load_state_dict(ckp['state_dict'])
    # optimizer.load_state_dict(ckp['optimizer_state_dict'])
    # last_epoch = ckp['epoch']
    # rang = range(last_epoch, args.num_epochs, 1)
    rang = range(args.num_epochs)

    # Training
    for epoch in rang:
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0
        # if args.leaveDataset == 2 and epoch == 0:
        #     dataloader.num_batches = dataloader.num_batches + start_epoch
        #     epoch += start_epoch

        # For each batch

        # Flag controlling whether previous hidden states carry over into the
        # current sequence (part of the statefulness experiment below).
        stateful = True

        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)  # shuffled input + stateless LSTM

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]], d, args.distance_thresh)

                nodes, edges, nodesPresent, edgesPresent, \
                    obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

                ## Modification (31st Jan): reset hidden and cell states only once per batch,
                ## keeping states updated across sequences.

                # Define hidden states
                numNodes = nodes.size()[1]
                numObsNodes = obsNodes.size()[1]

                # if not stateful:
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                ## Update (25th Jan): letting the hidden state transition begin with
                ## negative ones did not lead to new learning results; the network
                ## converged quickly and the loss did not decrease below -6.
                hidden_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                hidden_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                cell_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                cell_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()
                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, h_node_rnn, h_edge_rnn, cell_node_rnn, cell_edge_rnn, \
                    o_h_node_rnn, o_h_edge_rnn, o_cell_node_rnn, o_cell_edge_rnn, _ = net(
                        nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs,
                        obsNodes[:args.seq_length], obsEdges[:args.seq_length],
                        obsNodesPresent[:-1], obsEdgesPresent[:-1],
                        hidden_states_obs_node_RNNs, hidden_states_obs_edge_RNNs,
                        cell_states_obs_node_RNNs, cell_states_obs_edge_RNNs)

                # else (stateful experiment): carry hidden/cell states across sequences.
                #     if len(nodes) == len(hidden_states_node_RNNs):  # no additional nodes introduced in graph
                #         hidden_states_node_RNNs = Variable(h_node_rnn).cuda()
                #         cell_states_node_RNNs = Variable(cell_node_rnn).cuda()
                #     else:  # for now the number of nodes only increases as new pedestrians enter the scene
                #         pad_size = len(nodes) - len(hidden_states_node_RNNs)
                #         cell_states_node_RNNs = Variable(np.pad(cell_node_rnn, pad_size)).cuda()
                # if sequence > 0:
                #     hidden_states_node_RNNs = Variable(h_node_rnn).resize(hidden_states_node_RNNs.cpu().size())
                #     hidden_states_edge_RNNs = h_edge_rnn
                #     cell_states_node_RNNs = cell_node_rnn
                #     cell_states_edge_RNNs = cell_edge_rnn
                #     new_num_nodes = h_node_rnn.size()[0] - hidden_states_node_RNNs.size()[0]
                #     if new_num_nodes >= 1:
                #         np.pad(hidden_states_node_RNNs.cpu(), new_num_nodes, mode='constant').cuda()
                #     hidden_states_obs_node_RNNs = o_h_node_rnn
                #     hidden_states_obs_edge_RNNs = o_h_edge_rnn
                #     cell_states_obs_node_RNNs = o_cell_node_rnn
                #     cell_states_obs_edge_RNNs = o_cell_edge_rnn

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.data[0]
                # Compute gradients
                loss.backward(retain_graph=True)
                param_log.write(str(net.alpha.data[0]) + '\n')

                # Clip gradients
                torch.nn.utils.clip_grad_norm(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print(
                '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(epoch * dataloader.num_batches + batch,
                                                                                    args.num_epochs * dataloader.num_batches,
                                                                                    epoch,
                                                                                    loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch)+','+str(loss_epoch)+',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0
        # dataloader.valid_num_batches = dataloader.valid_num_batches + start_epoch
        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)  # stateless LSTM without shuffling

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]], d, args.distance_thresh)

                nodes, edges, nodesPresent, edgesPresent, \
                    obsNodes, obsEdges, obsNodesPresent, obsEdgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float()).cuda()
                edges = Variable(torch.from_numpy(edges).float()).cuda()

                obsNodes = Variable(torch.from_numpy(obsNodes).float()).cuda()
                obsEdges = Variable(torch.from_numpy(obsEdges).float()).cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size)).cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_human_edge_rnn_size)).cuda()

                numObsNodes = obsNodes.size()[1]
                hidden_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                hidden_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                cell_states_obs_node_RNNs = Variable(torch.zeros(numObsNodes, args.obs_node_rnn_size)).cuda()
                cell_states_obs_edge_RNNs = Variable(torch.zeros(numNodes * numNodes, args.human_obstacle_edge_rnn_size)).cuda()

                outputs, h_node_rnn, h_edge_rnn, cell_node_rnn, cell_edge_rnn, \
                    o_h_node_rnn, o_h_edge_rnn, o_cell_node_rnn, o_cell_edge_rnn, _ = net(
                        nodes[:args.seq_length], edges[:args.seq_length],
                        nodesPresent[:-1], edgesPresent[:-1],
                        hidden_states_node_RNNs, hidden_states_edge_RNNs,
                        cell_states_node_RNNs, cell_states_edge_RNNs,
                        obsNodes[:args.seq_length], obsEdges[:args.seq_length],
                        obsNodesPresent[:-1], obsEdgesPresent[:-1],
                        hidden_states_obs_node_RNNs, hidden_states_obs_edge_RNNs,
                        cell_states_obs_node_RNNs, cell_states_obs_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)

                loss_batch += loss.data[0]

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + '\n')
        # Save the model after each epoch
        print('Saving model')
        torch.save({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path(epoch))

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + ',' + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
    param_log.close()
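Example #3 carries commented-out code for resuming from a checkpoint, and Example #2 calls an undefined load_net for the same purpose. Since every example saves the same dictionary keys ('epoch', 'state_dict', 'optimizer_state_dict'), a resume helper can be sketched as follows (the helper name is an assumption; the path format mirrors checkpoint_path above):

import os

import torch


def load_checkpoint(save_directory, epoch, net, optimizer=None):
    # Restore a checkpoint written by the torch.save({...}) calls above.
    path = os.path.join(save_directory, 'srnn_model_' + str(epoch) + '.tar')
    ckp = torch.load(path)
    net.load_state_dict(ckp['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(ckp['optimizer_state_dict'])
    return ckp['epoch'] + 1  # first epoch of the resumed run

The training loop would then iterate over range(load_checkpoint(...), args.num_epochs), as the commented-out rang assignment hints.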
Example #4
def train(args):
    print("INPUT SEQUENCE LENGTH: {}".format(args.seq_length))
    print("OUTPUT SEQUENCE LENGTH: {}".format(args.pred_length))
    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            forcePreProcess=False)
    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = "../log/"

    # Logging file
    log_file_curve = open(os.path.join(log_directory, "log_curve.txt"), "w")
    log_file = open(os.path.join(log_directory, "val.txt"), "w")

    # Save directory
    save_directory = "../../save_weight/"

    # Open the configuration file and save the current argument settings
    with open(os.path.join(save_directory, "config.pkl"), "wb") as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, "4Dgraph.S-{}.P-{}.srnn_model_"\
                            .format(args.seq_length, args.pred_length)\
                            + str(x) + ".tar")

    # Initialize net
    net = SRNN(args)
    if args.use_cuda:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-5)

    # learning_rate = args.learning_rate
    logging.info("Training begin")
    best_val_loss = 100
    best_epoch = 0
    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)  # Initialization
        loss_epoch = 0

        # For each batch
        # dataloader.num_batches = 10; one epoch has 10 batches
        for batch in range(dataloader.num_batches):
            start = time.time()
            # Get batch data, mini-batch
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(
                    dataloader.batch_size):  # process each sequence in the mini-batch
                # Construct the graph for the current sequence in {nodes, edges}
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )  # graph information for each sequence in the mini-batch

                ##### Convert to cuda variables #####
                nodes = Variable(torch.from_numpy(nodes).float())
                # nodes[0] holds the coordinates of every person (object) that appears in frame 0.
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()

                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(  # NOTE: 0 for pedestrians, 1 for bicycles, 2 for vehicles
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                # Zero out the gradients // Initialization Step
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()
                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size  ##### NOTE: Expected Loss; E[L]
            loss_epoch += loss_batch

            logging.info(
                "{}/{} (epoch {}), train_loss = {:.12f}, time/batch = {:.12f}".
                format(
                    epoch * dataloader.num_batches + batch,
                    args.num_epochs * dataloader.num_batches,
                    epoch,
                    loss_batch,
                    end - start,
                ))
        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch) + "," + str(loss_epoch) + ",")

        #####################     Validation Part     #####################
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data

            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)

                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size  #### NOTE: Expected Loss; E[L]
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        # Record best epoch and best validation loss
        logging.info("(epoch {}), valid_loss = {:.3f}".format(
            epoch, loss_epoch))
        logging.info("Best epoch {}, Best validation loss {}".format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + "\n")

        # Save the model after each epoch
        logging.info("Saving model")
        torch.save(
            {
                "epoch": epoch,
                "state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path(epoch),
        )

    # Record the best epoch and best validation loss overall
    logging.info("Best epoch {}, Best validation loss {}".format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + "," + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
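Example #4 (and #5 below) repeats the same four-line block for every hidden and cell tensor: allocate zeros, then move to CUDA when args.use_cuda is set. A small factory would collapse that boilerplate (a sketch; the helper name is an assumption):

import torch
from torch.autograd import Variable


def zero_states(num_units, size, use_cuda):
    # One (hidden, cell) pair of zero-initialized LSTM states.
    hidden = Variable(torch.zeros(num_units, size))
    cell = Variable(torch.zeros(num_units, size))
    if use_cuda:
        hidden, cell = hidden.cuda(), cell.cuda()
    return hidden, cell

# Replacing the blocks above, e.g.:
# h_nodes, c_nodes = zero_states(numNodes, args.node_rnn_size, args.use_cuda)
# h_edges, c_edges = zero_states(numNodes * numNodes, args.edge_rnn_size, args.use_cuda)
# h_super, c_super = zero_states(3, args.node_rnn_size, args.use_cuda)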
Example #5
def train(args):
    # Construct the DataLoader object
    dataloader = DataLoader(args.batch_size,
                            args.seq_length + 1,
                            forcePreProcess=False)
    # Construct the ST-graph object
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = os.path.join(os.getcwd(), "log")
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging file
    log_file_curve = open(os.path.join(log_directory, "log_curve.txt"), "w+")
    log_file = open(os.path.join(log_directory, "val.txt"), "w")

    # Save directory
    save_directory = os.path.join(os.getcwd(), "save")
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Open the configuration file
    with open(os.path.join(save_directory, "config.pkl"), "wb") as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, "srnn_model_" + str(x) + ".tar")

    # Initialize net
    net = SRNN(args)
    net.load_state_dict(torch.load('modelref/srnn_model_271.tar'),
                        strict=False)
    if args.use_cuda:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-5)

    # learning_rate = args.learning_rate
    logging.info("Training begin")
    best_val_loss = 100
    best_epoch = 0

    global plotter
    plotter = VisdomLinePlotter(env_name='main')

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        # dataloader.num_batches = 10; one epoch has 10 batches
        for batch in range(dataloader.num_batches):
            start = time.time()
            # Get batch data
            x, _, _, d = dataloader.next_batch(randomUpdate=True)

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])
                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )
                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                # nodes[0] holds the coordinates of every person that appears in frame 0.
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()

                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()

                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()
                # Forward prop
                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()
                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            logging.info(
                "{}/{} (epoch {}), train_loss = {:.12f}, time/batch = {:.12f}".
                format(
                    epoch * dataloader.num_batches + batch,
                    args.num_epochs * dataloader.num_batches,
                    epoch,
                    loss_batch,
                    end - start,
                ))
        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        plotter.plot('loss', 'train', 'Class Loss', epoch, loss_epoch)
        # Log it
        log_file_curve.write(str(epoch) + "," + str(loss_epoch) + ",")

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data

            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence(
                )

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]

                hidden_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()

                hidden_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(
                    torch.zeros(numNodes, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(
                    torch.zeros(numNodes * numNodes, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                hidden_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_RNNs = hidden_states_super_node_RNNs.cuda(
                    )

                cell_states_super_node_RNNs = Variable(
                    torch.zeros(3, args.node_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_RNNs = cell_states_super_node_RNNs.cuda(
                    )

                hidden_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    hidden_states_super_node_Edge_RNNs = (
                        hidden_states_super_node_Edge_RNNs.cuda())

                cell_states_super_node_Edge_RNNs = Variable(
                    torch.zeros(3, args.edge_rnn_size))
                if args.use_cuda:
                    cell_states_super_node_Edge_RNNs = (
                        cell_states_super_node_Edge_RNNs.cuda())

                outputs, _, _, _, _, _, _, _, _, _ = net(
                    nodes[:args.seq_length],
                    edges[:args.seq_length],
                    nodesPresent[:-1],
                    edgesPresent[:-1],
                    hidden_states_node_RNNs,
                    hidden_states_edge_RNNs,
                    cell_states_node_RNNs,
                    cell_states_edge_RNNs,
                    hidden_states_super_node_RNNs,
                    hidden_states_super_node_Edge_RNNs,
                    cell_states_super_node_RNNs,
                    cell_states_super_node_Edge_RNNs,
                )

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:],
                                            nodesPresent[1:], args.pred_length)

                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches

        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch

        plotter.plot('loss', 'val', 'Class Loss', epoch, loss_epoch)

        # Record best epoch and best validation loss
        logging.info("(epoch {}), valid_loss = {:.3f}".format(
            epoch, loss_epoch))
        logging.info("Best epoch {}, Best validation loss {}".format(
            best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch) + "\n")

        # Save the model after each epoch
        logging.info("Saving model")
        torch.save(
            {
                "epoch": epoch,
                "state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path(epoch),
        )

    # Record the best epoch and best validation loss overall
    logging.info("Best epoch {}, Best validation loss {}".format(
        best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch) + "," + str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
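Example #5 plots its curves through a VisdomLinePlotter that the snippet never defines. A common minimal implementation, assuming the visdom package and a server started with python -m visdom.server (this is a sketch, not necessarily the class the author used):

import numpy as np
from visdom import Visdom


class VisdomLinePlotter(object):
    # Appends (x, y) points to named line plots in a visdom environment.

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, title_name, x, y):
        if var_name not in self.plots:
            # First point: create the window with a legend entry.
            self.plots[var_name] = self.viz.line(
                X=np.array([x, x]), Y=np.array([y, y]), env=self.env,
                opts=dict(legend=[split_name], title=title_name,
                          xlabel='epoch', ylabel=var_name))
        else:
            # Later points: append to the existing window and trace.
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env,
                          win=self.plots[var_name], name=split_name,
                          update='append')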
Example #6
def train(args):

    # Construct the DataLoader object
    ## args: (batch_size=50, seq_length=5, datasets=[0, 1, 2, 3, 4, 5, 6], forcePreProcess=False, infer=False)
    ## seq_length + 1 because the target sequence is the input shifted by one step in time
    dataloader = DataLoader(args.batch_size, args.seq_length + 1, args.train_dataset, forcePreProcess=True)

    # Construct the ST-graph object
    ## args: (batch_size=50, seq_length=5)
    ## batch size 1: the graph is built per sequence; seq_length + 1 as above
    stgraph = ST_GRAPH(1, args.seq_length + 1)

    # Log directory
    log_directory = 'log/trainedOn_' + str(args.train_dataset)
    if not os.path.exists(log_directory):
        os.makedirs(log_directory)

    # Logging file
    log_file_curve = open(os.path.join(log_directory, 'log_curve.txt'), 'w')
    log_file = open(os.path.join(log_directory, 'val.txt'), 'w')

    # Save directory for saving the model
    save_directory = 'save/trainedOn_' + str(args.train_dataset)
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Open the configuration file
    ## store arguments from parser
    with open(os.path.join(save_directory, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Path to store the checkpoint file, i.e. model after the particular epoch
    def checkpoint_path(x):
        return os.path.join(save_directory, 'srnn_model_'+str(x)+'.tar')

    # Initialize net
    net = SRNN(args)
    if args.use_cuda:
        net = net.cuda()

    # optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate, momentum=0.0001, centered=True)
    optimizer = torch.optim.Adagrad(net.parameters())

    learning_rate = args.learning_rate
    print('Training begin')
    best_val_loss = 100
    best_epoch = 0

    # Training
    for epoch in range(args.num_epochs):
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data
            ## Format:
            ## x_batch:     input sequence of length self.seq_length
            ## y_batch:     output seq of same length shifted by 1 step in time
            ## frame_batch: frame IDs in the batch
            ## d:           current position of dataset pointer (points to the next batch to be loaded)
            x, _, _, d = dataloader.next_batch(randomUpdate=True)
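            ## Example: with seq_length = 5, each x[sequence] holds 6
            ## consecutive frames; frames [0:5] feed the forward pass below
            ## and frames [1:6] serve as the one-step-shifted targets.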

            # Loss for this batch
            loss_batch = 0

            # For each sequence in the batch
            for sequence in range(dataloader.batch_size):
                # Construct the graph for the current sequence
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()
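                ## nodes is expected to be shaped (seq_length + 1, numNodes, 2)
                ## -- one (x, y) position per pedestrian per frame -- since
                ## numNodes is read from dimension 1 just below.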

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
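                ## Edge-RNN states are allocated for every ordered node pair,
                ## hence the numNodes * numNodes rows below (self-pairs included).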
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()

                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                # Zero out the gradients
                net.zero_grad()
                optimizer.zero_grad()

                # Forward prop
                outputs, _, _, _, _, _, _ = net(nodes[:args.seq_length], edges[:args.seq_length], nodesPresent[:-1], edgesPresent[:-1], hidden_states_node_RNNs, hidden_states_edge_RNNs, cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
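                ## Gaussian2DLikelihood presumably returns the negative
                ## log-likelihood of the true positions nodes[1:] under the
                ## predicted bivariate Gaussian parameters in outputs, with
                ## pred_length selecting the frames that contribute.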
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)
                loss_batch += loss.item()  # .data[0] fails on 0-dim tensors in PyTorch >= 0.5

                # Compute gradients
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)  # clip_grad_norm (no trailing underscore) is deprecated

                # Update parameters
                optimizer.step()

                # Reset the stgraph
                stgraph.reset()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

            print(
                '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(epoch * dataloader.num_batches + batch,
                                                                                    args.num_epochs * dataloader.num_batches,
                                                                                    epoch,
                                                                                    loss_batch, end - start))

        # Compute loss for the entire epoch
        loss_epoch /= dataloader.num_batches
        # Log it
        log_file_curve.write(str(epoch)+','+str(loss_epoch)+',')

        # Validation
        dataloader.reset_batch_pointer(valid=True)
        loss_epoch = 0

        for batch in range(dataloader.valid_num_batches):
            # Get batch data
            x, _, d = dataloader.next_valid_batch(randomUpdate=False)

            # Loss for this batch
            loss_batch = 0

            for sequence in range(dataloader.batch_size):
                stgraph.readGraph([x[sequence]])

                nodes, edges, nodesPresent, edgesPresent = stgraph.getSequence()

                # Convert to cuda variables
                nodes = Variable(torch.from_numpy(nodes).float())
                if args.use_cuda:
                    nodes = nodes.cuda()
                edges = Variable(torch.from_numpy(edges).float())
                if args.use_cuda:
                    edges = edges.cuda()

                # Define hidden states
                numNodes = nodes.size()[1]
                hidden_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    hidden_states_node_RNNs = hidden_states_node_RNNs.cuda()
                hidden_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    hidden_states_edge_RNNs = hidden_states_edge_RNNs.cuda()
                cell_states_node_RNNs = Variable(torch.zeros(numNodes, args.human_node_rnn_size))
                if args.use_cuda:
                    cell_states_node_RNNs = cell_states_node_RNNs.cuda()
                cell_states_edge_RNNs = Variable(torch.zeros(numNodes*numNodes, args.human_human_edge_rnn_size))
                if args.use_cuda:
                    cell_states_edge_RNNs = cell_states_edge_RNNs.cuda()

                outputs, _, _, _, _, _, _ = net(nodes[:args.seq_length], edges[:args.seq_length], nodesPresent[:-1], edgesPresent[:-1],
                                             hidden_states_node_RNNs, hidden_states_edge_RNNs,
                                             cell_states_node_RNNs, cell_states_edge_RNNs)

                # Compute loss
                loss = Gaussian2DLikelihood(outputs, nodes[1:], nodesPresent[1:], args.pred_length)

                loss_batch += loss.item()

                # Reset the stgraph
                stgraph.reset()

            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch

        loss_epoch = loss_epoch / dataloader.valid_num_batches


        # Save the model (every epoch if save_every is set, otherwise only on improvement)
        if loss_epoch < best_val_loss or args.save_every:
            print('Saving model')
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))

        # Alternative: overwrite a single best-model file whenever loss_epoch
        # improves, instead of keeping one checkpoint per epoch.


        # Update best validation loss until now
        if loss_epoch < best_val_loss:
            best_val_loss = loss_epoch
            best_epoch = epoch
                      
        # Record best epoch and best validation loss
        print('(epoch {}), valid_loss = {:.3f}'.format(epoch, loss_epoch))
        print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
        # Log it
        log_file_curve.write(str(loss_epoch)+'\n')

    # Record the best epoch and best validation loss overall
    print('Best epoch {}, Best validation loss {}'.format(best_epoch, best_val_loss))
    # Log it
    log_file.write(str(best_epoch)+','+str(best_val_loss))

    # Close logging files
    log_file.close()
    log_file_curve.close()
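As a closing illustration, here is a hedged sketch of invoking this train function. Every attribute name below is read somewhere in the code above; the values are placeholders, and the real SRNN constructor may require further attributes not visible in this excerpt:

from types import SimpleNamespace

# Placeholder arguments; the names are taken from the code above, the values
# are illustrative. SRNN(args) may read additional attributes not shown here.
args = SimpleNamespace(
    batch_size=50,
    seq_length=5,
    pred_length=3,
    train_dataset=0,
    use_cuda=False,
    human_node_rnn_size=64,
    human_human_edge_rnn_size=64,
    learning_rate=0.001,
    num_epochs=10,
    grad_clip=10.0,
    save_every=False,
)
train(args)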