Example #1
def main(args):
    test_path = get_dset_path(args.dataset_name, 'test')

    logger.info("Initializing test dataset")
    test_dset, test_loader = data_loader(args, test_path)

    net = LSTM_model(args)
    net = net.cuda()

    checkpoint_path = r".\model\lstm767.tar"
    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    batch_error = 0
    batch_fde = 0
    for idx, batch in enumerate(test_loader):

        (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
         loss_mask, seq_start_end) = batch
        num_ped = obs_traj.size(1)   # obs_traj has shape (obs_len=8, num_ped, 2)
        pred_traj_gt = pred_traj_gt.cuda()
        pred_traj = net(obs_traj.cuda(), num_ped, pred_traj_gt)

        ade_1 = get_mean_error(pred_traj, pred_traj_gt)
        # ADE: summed displacement error averaged over num_ped pedestrians and the 12 predicted steps
        ade_2 = displacement_error(pred_traj, pred_traj_gt) / (pred_traj.size(1) * 12)
        # FDE: final-step displacement error averaged over num_ped pedestrians
        fde = final_displacement_error(pred_traj, pred_traj_gt) / pred_traj.size(1)

        batch_error += ade_2
        batch_fde += fde
    ade = batch_error / (idx+1)
    fin_fde = batch_fde / (idx+1)
    logger.info("ade is {:.2f}".format(ade))
    logger.info("ade is {:.2f}".format(fin_fde))
Example #2
def main(args):
    train_path = get_dset_path(args.dataset_name, 'train')
    val_path = get_dset_path(args.dataset_name, 'val')

    # Random seeds (uncomment the lines below to fix them for reproducibility)
    # torch.manual_seed(2)
    # np.random.seed(2)
    # if args.use_gpu:
    #     torch.cuda.manual_seed_all(2)

    logger.info("Initializing train dataset")
    train_dset, train_loader = data_loader(args, train_path)
    logger.info("Initializing val dataset")
    _, val_loader = data_loader(args, val_path)

    log_path = './log/'
    log_file_curve = open(os.path.join(log_path, 'log_loss.txt'), 'w+')
    log_file_curve_val = open(os.path.join(log_path, 'log_loss_val.txt'), 'w+')
    log_file_curve_val_ade = open(
        os.path.join(log_path, 'log_loss_val_ade.txt'), 'w+')

    net = LSTM_model(args)
    if args.use_gpu:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

    # Resume training from the previous checkpoint (uncomment to restore)
    # restore_path = '.\model\lstm294.tar'
    # logger.info('Restoring from checkpoint {}'.format(restore_path))
    # checkpoint = torch.load(restore_path)
    # net.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #
    # for i_epoch in range(checkpoint['epoch']+1):
    #     if (i_epoch + 1) % 100 == 0:
    #         args.learning_rate *= 0.98

    epoch_loss_min = 160
    epoch_smallest = 0
    #for epoch in range(checkpoint['epoch']+1, args.num_epochs):
    for epoch in range(args.num_epochs):
        count = 0
        batch_loss = 0

        for batch in train_loader:
            # Zero out gradients
            net.zero_grad()
            optimizer.zero_grad()

            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
             non_linear_ped, loss_mask, seq_start_end) = batch
            num_ped = obs_traj.size(1)

            #model_teacher.py
            pred_traj = net(obs_traj, num_ped, pred_traj_gt, seq_start_end)
            loss = displacement_error(pred_traj, pred_traj_gt)
            #loss = get_mean_error(pred_traj, pred_traj_gt)

            # Compute gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
            # Update parameters
            optimizer.step()

            batch_loss += loss
            count += 1

            #print(loss / num_ped)
        if (epoch + 1) % 6 == 0:
            pass
            #scheduler.step()
        logger.info('epoch {} train loss is {}'.format(epoch,
                                                       batch_loss / count))
        log_file_curve.write(str(batch_loss.item() / count) + "\n")

        # Validate without tracking gradients.
        net.eval()
        batch_loss = 0
        val_ade = 0
        total_ade = 0
        with torch.no_grad():
            for idx, batch in enumerate(val_loader):
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
                 non_linear_ped, loss_mask, seq_start_end) = batch
                num_ped = obs_traj.size(1)

                # model_teacher.py
                pred_traj = net(obs_traj, num_ped, pred_traj_gt, seq_start_end)
                loss = displacement_error(pred_traj, pred_traj_gt)

                batch_loss += loss
                val_ade = loss / (num_ped * 12)  # per-batch ADE
                total_ade += val_ade

                count += 1
        net.train()

        fin_ade = total_ade / (idx + 1)
        log_file_curve_val_ade.write(str(fin_ade.item()) + "\n")

        epoch_loss = batch_loss / count
        if epoch_loss_min > epoch_loss:
            epoch_loss_min = epoch_loss
            epoch_smallest = epoch

            logger.info('Saving model')
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, checkpoint_path(epoch))
        logger.info('epoch {} val loss is {}'.format(epoch, epoch_loss))
        log_file_curve_val.write(str(epoch_loss.item()) + "\n")
        logger.info('epoch {} has the smallest val loss: {}'.format(
            epoch_smallest, epoch_loss_min))
        logger.info('epoch {} val ade is {}'.format(epoch, fin_ade))
        logger.info("-" * 50)