Пример #1
0
def main(args):
    """Set up KITTI2015 data and a PSMNet model, then run validation once per epoch.

    This variant does not train: every epoch switches the model to eval mode,
    calls ``validate`` and prints the elapsed time. Checkpoint saving is
    disabled (the ``save`` call below is commented out).

    Args:
        args: parsed command-line options. Reads ``datadir``, ``batch_size``,
            ``num_workers``, ``validate_batch_size``, ``maxdisp``, ``lr``,
            ``model_path`` and ``num_epochs``.
    """
    train_transform = T.Compose([RandomCrop([256, 512]), Normalize(mean, std), ToTensor()])
    train_dataset = KITTI2015(args.datadir, mode='train', transform=train_transform)
    # NOTE(review): train_loader is built but never used in this validation-only loop.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    validate_transform = T.Compose([Normalize(mean, std), ToTensor(), Pad(384, 1248)])
    validate_dataset = KITTI2015(args.datadir, mode='validate', transform=validate_transform)
    validate_loader = DataLoader(validate_dataset, batch_size=args.validate_batch_size, num_workers=args.num_workers)

    step = 0
    best_error = 100.0

    model = PSMNet(args.maxdisp).to(device)
    # Wrap the model for data-parallel execution across device_ids.
    model = nn.DataParallel(model, device_ids=device_ids)
    criterion = SmoothL1Loss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.model_path is not None:
        # map_location lets a checkpoint saved on another device (e.g. a GPU) load here.
        state = torch.load(args.model_path, map_location=device)
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        step = state['step']
        best_error = state['error']
        print('load model from {}'.format(args.model_path))
    else:
        print('no model_path given, using randomly initialized weights')

    print('Number of model parameters: {}'.format(sum(p.numel() for p in model.parameters())))

    for epoch in range(1, args.num_epochs + 1):
        # perf_counter is monotonic, so the measured duration is immune to system clock changes.
        time_start = time.perf_counter()
        model.eval()
        error = validate(model, validate_loader, epoch)
        # best_error = save(model, optimizer, epoch, step, error, best_error)
        time_end = time.perf_counter()
        print('该epoch运行时间:', time_end - time_start, '秒')
Пример #2
0
def main(args):
    """Train PSMNet on KITTI2015, validating and checkpointing every ``save_per_epoch`` epochs.

    Args:
        args: parsed command-line options. Reads ``datadir``, ``batch_size``,
            ``num_workers``, ``validate_batch_size``, ``maxdisp``, ``lr``,
            ``model_path``, ``num_epochs`` and ``save_per_epoch``.
    """
    # 1. Load the training and validation data.
    train_transform = T.Compose(
        [RandomCrop([256, 512]),
         Normalize(mean, std),
         ToTensor()])
    train_dataset = KITTI2015(args.datadir,
                              mode='train',
                              transform=train_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)

    validate_transform = T.Compose(
        [Normalize(mean, std),
         ToTensor(), Pad(384, 1248)])
    validate_dataset = KITTI2015(args.datadir,
                                 mode='validate',
                                 transform=validate_transform)
    validate_loader = DataLoader(validate_dataset,
                                 batch_size=args.validate_batch_size,
                                 num_workers=args.num_workers)

    step = 0
    best_error = 100.0

    model = PSMNet(args.maxdisp).to(device)
    # model = nn.DataParallel(model, device_ids=device_ids)  # multi-GPU run

    criterion = SmoothL1Loss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.model_path is not None:
        # Resume from a checkpoint when a model path is given.
        # map_location lets a checkpoint saved on another device (e.g. a GPU) load here.
        state = torch.load(args.model_path, map_location=device)
        # The model here is not wrapped in nn.DataParallel, so strip the
        # ``module.`` prefix a wrapped model would have written; keys without
        # the prefix are kept unchanged.
        prefix = 'module.'
        state['state_dict'] = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state['state_dict'].items()
        }
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        step = state['step']
        best_error = state['error']
        print('load model from {}'.format(args.model_path))

    # Print the number of model parameters.
    print('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Start training.
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        step = train(model, train_loader, optimizer, criterion, step)
        adjust_lr(optimizer, epoch)  # update the learning rate after each epoch

        if epoch % args.save_per_epoch == 0:
            model.eval()
            error = validate(model, validate_loader, epoch)
            best_error = save(model, optimizer, epoch, step, error, best_error)
Пример #3
0
def main():
    """Run PSMNet on one stereo pair and save the predicted disparity map as an image.

    Uses the module-level ``args``: ``left`` / ``right`` (input image paths),
    ``maxdisp``, ``model_path`` and ``save_path``.

    Raises:
        FileNotFoundError: if either input image cannot be read by OpenCV.
    """
    from collections import OrderedDict

    left = cv2.imread(args.left)
    right = cv2.imread(args.right)
    # cv2.imread returns None instead of raising on a missing/unreadable file;
    # fail here with a clear message rather than deep inside the transform.
    for path, image in ((args.left, left), (args.right, right)):
        if image is None:
            raise FileNotFoundError('cannot read image: {}'.format(path))

    pairs = {'left': left, 'right': right}

    transform = T.Compose([Normalize(mean, std), ToTensor(), Pad(384, 1248)])
    pairs = transform(pairs)
    # Add a leading batch dimension of size 1.
    left = pairs['left'].to(device).unsqueeze(0)
    right = pairs['right'].to(device).unsqueeze(0)

    model = PSMNet(args.maxdisp).to(device)
    if len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)

    # map_location lets a checkpoint saved on another device (e.g. a GPU) load here.
    state = torch.load(args.model_path, map_location=device)

    # nn.DataParallel prefixes parameter names with ``module.``. Make the
    # checkpoint keys match the model: add the prefix for a wrapped model,
    # strip it for a bare one. Keys already in the right form are untouched.
    prefix = 'module.'
    wrapped = isinstance(model, nn.DataParallel)
    new_state_dict = OrderedDict()
    for k, v in state['state_dict'].items():
        has_prefix = k.startswith(prefix)
        if wrapped and not has_prefix:
            k = prefix + k
        elif not wrapped and has_prefix:
            k = k[len(prefix):]
        new_state_dict[k] = v
    state['state_dict'] = new_state_dict

    model.load_state_dict(state['state_dict'])
    print('load model from {}'.format(args.model_path))
    print('epoch: {}'.format(state['epoch']))
    print('3px-error: {}%'.format(state['error']))

    model.eval()

    with torch.no_grad():
        # The model returns three outputs; only the last is used as the disparity map.
        _, _, disp = model(left, right)

    # Drop the batch dimension and move to host memory for plotting.
    disp = disp.squeeze(0).detach().cpu().numpy()
    # figsize * dpi=100 gives a 1284x384-pixel figure.
    fig = plt.figure(figsize=(12.84, 3.84))
    plt.axis('off')
    plt.imshow(disp)
    plt.colorbar()
    plt.savefig(args.save_path, dpi=100)
    plt.close(fig)  # release the figure instead of leaving it open in pyplot

    print('save disparity map in {}'.format(args.save_path))