예제 #1
0
            # Forward pass
            # NOTE(review): fragment of a larger validation loop whose header is not
            # visible here; `params`, `net`, `epoch_num`, `pretrainEpochs`, the batch
            # tensors and the masked loss helpers all come from the enclosing scope.
            if params.use_maneuvers:
                if epoch_num < pretrainEpochs:
                    # During pre-training with MSE loss, validate with MSE for true maneuver class trajectory
                    # Setting train_flag presumably makes `net` return only the trajectory
                    # for the given lat_enc/lon_enc maneuver — TODO confirm. The flag is not
                    # reset anywhere in the visible code.
                    net.train_flag = True
                    fut_pred, _, _ = net(hist, nbrs, mask, lat_enc, lon_enc,
                                         hist_grid)
                    l = maskedMSE(fut_pred, fut, op_mask)
                else:
                    # During training with NLL loss, validate with NLL over multi-modal distribution
                    fut_pred, lat_pred, lon_pred = net(hist, nbrs, mask,
                                                       lat_enc, lon_enc,
                                                       hist_grid)
                    l = maskedNLLTest(fut_pred,
                                      lat_pred,
                                      lon_pred,
                                      fut,
                                      op_mask,
                                      avg_along_time=True)
                    # Accumulate per-batch maneuver classification accuracy: the fraction
                    # of samples whose arg-max along dim 1 of the prediction equals the
                    # arg-max of the target encoding.
                    avg_val_lat_acc += (torch.sum(
                        torch.max(lat_pred.data, 1)[1] == torch.max(
                            lat_enc.data, 1)[1])).item() / lat_enc.size()[0]
                    avg_val_lon_acc += (torch.sum(
                        torch.max(lon_pred.data, 1)[1] == torch.max(
                            lon_enc.data, 1)[1])).item() / lon_enc.size()[0]
            else:
                # Without maneuvers the net returns a single prediction; use MSE during
                # pre-training epochs and NLL afterwards.
                fut_pred = net(hist, nbrs, mask, lat_enc, lon_enc, hist_grid)
                if epoch_num < pretrainEpochs:
                    l = maskedMSE(fut_pred, fut, op_mask)
                else:
                    l = maskedNLL(fut_pred, fut, op_mask)
def train_model():
    """Train PiP on ``args.train_set`` with periodic validation on ``args.val_set``.

    Training runs ``args.pretrain_epochs`` epochs with a masked MSE loss,
    followed by ``args.train_epochs`` epochs with masked NLL plus binary
    cross-entropy on the lateral and longitudinal maneuver predictions.
    Every 100 training batches, the running train loss is printed and
    optionally written to TensorBoard. At that point up to ``eval_batch_num``
    validation batches are also evaluated. A checkpoint is written to
    ``./trained_models/<name>/`` after every epoch and once more at the end.
    """
    args = parser.parse_args()
    print("------------- {} -------------".format(args.name))
    print("Batch size : {}".format(args.batch_size))
    print("Learning rate : {}".format(args.learning_rate))
    print("Use Planning Coupled: {}".format(args.use_planning))
    print("Use Target Fusion: {}".format(args.use_fusion))

    ## Initialize network and optimizer
    PiP = pipNet(args)
    if args.use_cuda:
        PiP = PiP.cuda()
    optimizer = torch.optim.Adam(PiP.parameters(), lr=args.learning_rate)
    crossEnt = torch.nn.BCELoss()

    ## Initialize the log folder
    log_path = "./trained_models/{}/".format(args.name)
    os.makedirs(log_path, exist_ok=True)
    if args.tensorboard:
        logger = SummaryWriter(log_path + 'train-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))
        logger_val = SummaryWriter(log_path + 'validation-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))

    ## Initialize training parameters
    pretrainEpochs = args.pretrain_epochs
    trainEpochs = args.train_epochs
    batch_size = args.batch_size
    # Maximum number of validation batches evaluated at each in-training check.
    eval_batch_num = 20

    ## Initialize data loaders
    print("Train dataset: {}".format(args.train_set))
    trSet = highwayTrajDataset(path=args.train_set,
                               targ_enc_size=args.social_context_size + args.dynamics_encoding_size,
                               grid_size=args.grid_size,
                               fit_plan_traj=False)
    print("Validation dataset: {}".format(args.val_set))
    valSet = highwayTrajDataset(path=args.val_set,
                                targ_enc_size=args.social_context_size + args.dynamics_encoding_size,
                                grid_size=args.grid_size,
                                fit_plan_traj=True)
    trDataloader = DataLoader(trSet, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=trSet.collate_fn)
    valDataloader = DataLoader(valSet, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=valSet.collate_fn)
    print("DataSet Prepared : {} train data, {} validation data\n".format(len(trSet), len(valSet)))
    print("Network structure: {}\n".format(PiP))

    def unpack_batch(data):
        """Return the ten tensors of a collated batch, moved to the GPU when ``args.use_cuda``.

        The trailing element of the collated tuple is not used here and is dropped.
        """
        *tensors, _ = data
        if args.use_cuda:
            tensors = [t.cuda() for t in tensors]
        return tensors

    ## Training process
    for epoch_num in range(pretrainEpochs + trainEpochs):
        # Two independent checks so the NLL banner is still printed when there is
        # no pre-training phase (pretrainEpochs == 0).
        if epoch_num == 0 and pretrainEpochs > 0:
            print('Pretrain with MSE loss')
        if epoch_num == pretrainEpochs:
            print('Train with NLL loss')
        # TensorBoard tag for both train and validation curves of this epoch.
        loss_tag = "RMSE" if epoch_num < pretrainEpochs else "NLL"

        ## Variables to track training performance
        avg_time_tr, avg_loss_tr = 0, 0
        ## Training status, reclaim after each epoch
        PiP.train()
        PiP.train_output_flag = True
        for i, data in enumerate(trDataloader):
            st_time = time.time()
            (nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
             targsFut, targsFutMask, lat_enc, lon_enc) = unpack_batch(data)

            # Forward pass
            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, lat_enc, lon_enc)
            if epoch_num < pretrainEpochs:
                # Pre-train with MSE loss to speed up training
                loss = maskedMSE(fut_pred, targsFut, targsFutMask)
            else:
                # Train with NLL loss plus the two maneuver classification losses
                loss = (maskedNLL(fut_pred, targsFut, targsFutMask)
                        + crossEnt(lat_pred, lat_enc) + crossEnt(lon_pred, lon_enc))

            # Back-prop, clip the total gradient norm to 10, and update weights
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(PiP.parameters(), 10)
            optimizer.step()

            # Track average train loss and average train time
            avg_loss_tr += loss.item()
            avg_time_tr += time.time() - st_time

            # For every 100 batches: record loss and run a short validation.
            if i % 100 == 99:
                eta = avg_time_tr / 100 * (len(trSet) / batch_size - i)
                epoch_progress = i * batch_size / len(trSet)
                print("Epoch no:", epoch_num + 1,
                      "| Epoch progress(%):", format(epoch_progress * 100, '0.2f'),
                      "| Avg train loss:", format(avg_loss_tr / 100, '0.2f'),
                      "| ETA(s):", int(eta))

                # x-axis position shared by the train and validation curves
                global_step = (epoch_progress + epoch_num) * 100
                if args.tensorboard:
                    logger.add_scalar(loss_tag, avg_loss_tr / 100, global_step)

                ## Validation during training on at most eval_batch_num batches
                val_loss_sum, val_batches = 0.0, 0
                with torch.no_grad():
                    PiP.eval()
                    PiP.train_output_flag = False
                    # A separate loop variable keeps the training index `i` intact.
                    for val_data in valDataloader:
                        (nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
                         targsFut, targsFutMask, lat_enc, lon_enc) = unpack_batch(val_data)
                        if epoch_num < pretrainEpochs:
                            # During pre-training with MSE loss, validate with MSE for true maneuver class trajectory
                            PiP.train_output_flag = True
                            fut_pred, _, _ = PiP(nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
                                                 lat_enc, lon_enc)
                            loss = maskedMSE(fut_pred, targsFut, targsFutMask)
                        else:
                            # During training with NLL loss, validate with NLL over multi-modal distribution
                            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask, targsHist,
                                                               targsEncMask, lat_enc, lon_enc)
                            loss = maskedNLLTest(fut_pred, lat_pred, lon_pred, targsFut, targsFutMask, avg_along_time=True)
                        val_loss_sum += loss.item()
                        val_batches += 1
                        if val_batches == eval_batch_num:
                            break
                # Average over the batches actually evaluated, so a validation set
                # shorter than eval_batch_num batches is still logged.
                if args.tensorboard and val_batches > 0:
                    logger_val.add_scalar(loss_tag, val_loss_sum / val_batches, global_step)

                # Clear statistics and revert to train mode after in-process evaluation.
                avg_time_tr, avg_loss_tr = 0, 0
                PiP.train()
                PiP.train_output_flag = True

        ## Save the model after each epoch
        epoCount = epoch_num + 1
        if epoCount < pretrainEpochs:
            ckpt_name = "{}-pre{}-nll{}.tar".format(args.name, epoCount, 0)
        else:
            ckpt_name = "{}-pre{}-nll{}.tar".format(args.name, pretrainEpochs, epoCount - pretrainEpochs)
        torch.save(PiP.state_dict(), log_path + ckpt_name)

    # All epochs finished
    torch.save(PiP.state_dict(), log_path + "{}.tar".format(args.name))
    print("Model saved in trained_models/{}/{}.tar\n".format(args.name, args.name))
    if args.tensorboard:
        # Flush any buffered TensorBoard events to disk and release the files.
        logger.close()
        logger_val.close()
# Accumulate per-time-step loss sums and counts over the whole test loader, then
# print either the mean NLL or the RMSE in meters. Relies on module-level
# `tsDataloader`, `args`, `net`, `metric`, `lossVals` and `counts` defined elsewhere.
for i, data in enumerate(tsDataloader):
    st_time = time.time()
    hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask, graph_list, _, _, nbrs_idx = data

    # Initialize Variables (lat_enc / lon_enc are not used by the path-mode forward pass)
    if args['use_cuda']:
        hist = hist.to(args['device'])
        nbrs = nbrs.to(args['device'])
        mask = mask.to(args['device'])
        fut = fut.to(args['device'])
        op_mask = op_mask.to(args['device'])
        nbrs_idx = nbrs_idx.to(args['device'])

    # Forward pass: identical for both metrics, so run it once.
    fut_pred = net(hist, nbrs, mask, graph_list, nbrs_idx, mode='path')
    if metric == 'nll':
        l, c = maskedNLLTest(fut_pred, 0, 0, fut, op_mask, use_maneuvers=False)
    else:
        l, c = maskedMSETest(fut_pred, fut, op_mask)

    lossVals += l.detach()
    counts += c.detach()

if metric == 'nll':
    print(lossVals / counts)
else:
    print(torch.pow(lossVals / counts, 0.5) *
          0.3048)  # Calculate RMSE and convert from feet to meters
예제 #4
0
def _select_most_probable_trajectory(fut_pred, lat_pred, lon_pred, num_lat_classes=3):
    """Pick, for each sample, the predicted trajectory of its arg-max maneuver.

    ``fut_pred`` is a sequence of per-maneuver prediction tensors indexed by
    maneuver id ``lon * num_lat_classes + lat``; dim 1 of each entry is the
    sample dimension. ``lat_pred`` / ``lon_pred`` hold per-sample maneuver
    scores along dim 1. Returns a tensor laid out like ``fut_pred[0]`` whose
    column ``k`` (along dim 1) is taken from sample ``k``'s arg-max maneuver.
    """
    # (maneuvers, d0, samples, ...)
    stacked = torch.stack(list(fut_pred))
    lat_man = torch.argmax(lat_pred, dim=1)
    lon_man = torch.argmax(lon_pred, dim=1)
    maneuver_idx = (lon_man * num_lat_classes + lat_man).to(stacked.device)
    sample_idx = torch.arange(lat_pred.shape[0], device=stacked.device)
    # Advanced indices on dims 0 and 2 are non-adjacent, so the result is
    # (samples, d0, ...); swap back to fut_pred[0]'s (d0, samples, ...) layout.
    return stacked[maneuver_idx, :, sample_idx].transpose(0, 1).contiguous()


def model_evaluate():
    """Evaluate a trained PiP checkpoint on ``args.test_set`` and log RMSE / NLL.

    The checkpoint is read from ``./trained_models/<prefix>/<name>.tar``, where
    ``<prefix>`` is ``args.name`` up to its first ``'-'``.

    With ``args.metric == 'agent'``, losses are averaged per (dataset, target
    vehicle) pair first and then across vehicles. With ``'sample'``, they are
    averaged over all predicted samples.

    RMSE is computed on the most probable maneuver's trajectory and converted
    from feet to meters. Metrics at every 5th output step are logged.

    Raises:
        RuntimeError: if ``args.metric`` is neither ``'agent'`` nor ``'sample'``.
    """
    args = parser.parse_args()
    model_dir = (args.name).split('-')[0]
    # Device used for the 'sample'-metric accumulators; must match the predictions.
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    ## Initialize network
    PiP = pipNet(args)
    # Load onto the CPU first so GPU-saved checkpoints also open on CPU-only
    # hosts; the model is moved to the GPU below when requested.
    PiP.load_state_dict(
        torch.load('./trained_models/{}/{}.tar'.format(model_dir, args.name),
                   map_location='cpu'))
    if args.use_cuda:
        PiP = PiP.cuda()

    ## Evaluation Mode
    PiP.eval()
    PiP.train_output_flag = False
    initLogging(log_file='./trained_models/{}/evaluation.log'.format(model_dir))

    ## Initialize dataset
    logging.info("Loading test data from {}...".format(args.test_set))
    tsSet = highwayTrajDataset(path=args.test_set,
                               targ_enc_size=args.social_context_size +
                               args.dynamics_encoding_size,
                               grid_size=args.grid_size,
                               fit_plan_traj=True,
                               fit_plan_further_ds=args.plan_info_ds)
    logging.info("TOTAL :: {} test data.".format(len(tsSet)))
    tsDataloader = DataLoader(tsSet,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=tsSet.collate_fn)

    ## Loss statistic
    logging.info(
        "<{}> evaluated by {}-based NLL & RMSE, with planning input of {}s step."
        .format(args.name, args.metric, args.plan_info_ds * 0.2))
    if args.metric == 'agent':
        # Accumulators indexed [dataset id, target vehicle id, output step].
        stat_shape = (np.max(tsSet.Data[:, 0]).astype(int) + 1,
                      np.max(tsSet.Data[:, 13:(13 + tsSet.grid_cells)]).astype(int) + 1,
                      args.out_length)
        nll_loss_stat = np.zeros(stat_shape)
        rmse_loss_stat = np.zeros(stat_shape)
        both_count_stat = np.zeros(stat_shape)
    elif args.metric == 'sample':
        # Per-output-step accumulators, sized like the 'agent' statistics and
        # placed on the model's device so CPU-only evaluation works.
        rmse_loss = torch.zeros(args.out_length, device=device)
        rmse_counts = torch.zeros(args.out_length, device=device)
        nll_loss = torch.zeros(args.out_length, device=device)
        nll_counts = torch.zeros(args.out_length, device=device)
    else:
        raise RuntimeError("Wrong type of evaluation metric is specified")
    avg_eva_time = 0

    ## Evaluation process
    with torch.no_grad():
        for i, data in enumerate(tsDataloader):
            st_time = time.time()
            # The last collated field is passed to batchTargetVehsInfo to look up ids.
            *tensors, idxs = data
            if args.use_cuda:
                tensors = [t.cuda() for t in tensors]
            (nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
             targsFut, targsFutMask, lat_enc, lon_enc) = tensors

            # Inference
            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut,
                                               planMask, targsHist,
                                               targsEncMask, lat_enc, lon_enc)
            # RMSE is measured on the trajectory of the most probable maneuver.
            fut_pred_max = _select_most_probable_trajectory(fut_pred, lat_pred, lon_pred)

            # Performance metric
            if args.metric == 'agent':
                dsIDs, targsIDs = tsSet.batchTargetVehsInfo(idxs)
                l, c = maskedNLLTest(fut_pred,
                                     lat_pred,
                                     lon_pred,
                                     targsFut,
                                     targsFutMask,
                                     separately=True)
                ll, _ = maskedMSETest(fut_pred_max,
                                      targsFut,
                                      targsFutMask,
                                      separately=True)
                l = l.detach().cpu().numpy()
                ll = ll.detach().cpu().numpy()
                c = c.detach().cpu().numpy()
                # Per-sample columns are scattered into the per-vehicle statistics.
                for j, targ in enumerate(targsIDs):
                    dsID = dsIDs[j]
                    nll_loss_stat[dsID, targ, :] += l[:, j]
                    rmse_loss_stat[dsID, targ, :] += ll[:, j]
                    both_count_stat[dsID, targ, :] += c[:, j]
            elif args.metric == 'sample':
                l, c = maskedNLLTest(fut_pred, lat_pred, lon_pred, targsFut,
                                     targsFutMask)
                nll_loss += l.detach()
                nll_counts += c.detach()
                l, c = maskedMSETest(fut_pred_max, targsFut, targsFutMask)
                rmse_loss += l.detach()
                rmse_counts += c.detach()

            # Time estimate
            batch_time = time.time() - st_time
            avg_eva_time += batch_time
            if i % 100 == 99:
                eta = avg_eva_time / 100 * (len(tsSet) / args.batch_size - i)
                logging.info("Evaluation progress(%):{:.2f}".format(
                    i / (len(tsSet) / args.batch_size) * 100, ) +
                             " | ETA(s):{}".format(int(eta)))
                avg_eva_time = 0

    # Result Summary
    if args.metric == 'agent':
        # Loss averaged from all predicted vehicles (those counted at the first step).
        ds_ids, veh_ids = both_count_stat[:, :, 0].nonzero()
        num_vehs = len(veh_ids)
        rmse_loss_averaged = np.zeros((args.out_length, num_vehs))
        nll_loss_averaged = np.zeros((args.out_length, num_vehs))
        count_averaged = np.zeros((args.out_length, num_vehs))
        for i in range(num_vehs):
            count_averaged[:, i] = \
                both_count_stat[ds_ids[i], veh_ids[i], :].astype(bool)
            rmse_loss_averaged[:, i] = rmse_loss_stat[ds_ids[i], veh_ids[i], :] \
                * count_averaged[:, i] / (both_count_stat[ds_ids[i], veh_ids[i], :] + 1e-9)
            nll_loss_averaged[:, i] = nll_loss_stat[ds_ids[i], veh_ids[i], :] \
                * count_averaged[:, i] / (both_count_stat[ds_ids[i], veh_ids[i], :] + 1e-9)
        rmse_loss_sum = np.sum(rmse_loss_averaged, axis=1)
        nll_loss_sum = np.sum(nll_loss_averaged, axis=1)
        count_sum = np.sum(count_averaged, axis=1)
        rmseOverall = np.power(
            rmse_loss_sum / count_sum,
            0.5) * 0.3048  # Unit converted from feet to meter.
        nllOverall = nll_loss_sum / count_sum
    elif args.metric == 'sample':
        rmseOverall = (torch.pow(rmse_loss / rmse_counts, 0.5) * 0.3048).cpu()
        nllOverall = (nll_loss / nll_counts).cpu()

    # Print the metrics every 5 time frame (1s)
    logging.info("RMSE (m)\t=> {}, Mean={:.3f}".format(
        rmseOverall[4::5], rmseOverall[4::5].mean()))
    logging.info("NLL (nats)\t=> {}, Mean={:.3f}".format(
        nllOverall[4::5], nllOverall[4::5].mean()))