Example #1
def test(loadstate):

    if loadstate:
        checkpoint = torch.load('./checkpoint/ckpt.t7')
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        accu = checkpoint['accur']
    net.eval()
    imL = Variable(torch.FloatTensor(1).cuda())
    imR = Variable(torch.FloatTensor(1).cuda())
    dispL = Variable(torch.FloatTensor(1).cuda())

    dataset = sceneDisp('', 'test', tsfm)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1)
    data_iter = iter(dataloader)
    data = next(data_iter)

    randomH = np.random.randint(0, 160)
    randomW = np.random.randint(0, 400)
    print('test')
    imageL = data['imL'][:, :, randomH:(randomH + h), randomW:(randomW + w)]
    imageR = data['imR'][:, :, randomH:(randomH + h), randomW:(randomW + w)]
    disL = data['dispL'][:, :, randomH:(randomH + h), randomW:(randomW + w)]
    imL.resize_(imageL.size()).copy_(imageL)
    imR.resize_(imageR.size()).copy_(imageR)
    dispL.resize_(disL.size()).copy_(disL)
    loss_mul_list_test = []
    for d in range(maxdisp):
        loss_mul_temp = Variable(torch.Tensor(np.ones([1, 1, h, w]) *
                                              d)).cuda()
        loss_mul_list_test.append(loss_mul_temp)
    loss_mul_test = torch.cat(loss_mul_list_test, 1)

    with torch.no_grad():
        result = net(imL, imR)

    disp = torch.sum(result.mul(loss_mul_test), 1)
    diff = torch.abs(disp.cpu() - dispL.cpu())  # end-point-error

    accuracy = torch.sum(diff < 3).item() / float(h * w)
    print('test accuracy less than 3 pixels:%f' % accuracy)

    # save
    im = disp.cpu().numpy().astype('uint8')
    im = np.transpose(im, (1, 2, 0))
    cv2.imwrite('test_result.png', im, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
    gt = np.transpose(dispL[0, :, :, :].cpu().numpy(), (1, 2, 0))
    cv2.imwrite('test_gt.png', gt, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
    return disp
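
The loss_mul loop in this example materialises one full [1, 1, h, w] tensor per disparity level and concatenates them. A minimal vectorised sketch of the same soft-argmin style disparity regression, assuming (as the torch.sum(result.mul(loss_mul_test), 1) call above implies) that the network output already acts as a per-pixel weight volume over disparities; disparity_regression is a hypothetical helper name, not part of the original code:

import torch

def disparity_regression(prob_volume, maxdisp):
    # prob_volume: [N, maxdisp, H, W] weights over candidate disparities.
    # Build the disparity index volume once via broadcasting instead of
    # concatenating maxdisp full-size tensors as in the example above.
    disp_values = torch.arange(maxdisp,
                               dtype=prob_volume.dtype,
                               device=prob_volume.device).view(1, maxdisp, 1, 1)
    return torch.sum(prob_volume * disp_values, dim=1)  # [N, H, W]
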
Example #2
def train(epoch_total, loadstate):

    loss_mul_list = []
    for d in range(maxdisp):
        loss_mul_temp = Variable(torch.Tensor(np.ones([batch, 1, h, w]) *
                                              d)).cuda()
        loss_mul_list.append(loss_mul_temp)
    loss_mul = torch.cat(loss_mul_list, 1)

    optimizer = optim.RMSprop(net.parameters(), lr=0.001, alpha=0.9)
    dataset = sceneDisp('', 'train', tsfm)
    _, H, W = dataset[0]['imL'].shape
    loss_fn = nn.L1Loss()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch,
                                             shuffle=True,
                                             num_workers=1)

    imL = Variable(torch.FloatTensor(1).cuda())
    imR = Variable(torch.FloatTensor(1).cuda())
    dispL = Variable(torch.FloatTensor(1).cuda())

    loss_list = []
    start_epoch = 0

    writer = SummaryWriter()
    n_iter = 0

    if loadstate:
        checkpoint = torch.load('./checkpoint/ckpt.t7')
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        accu = checkpoint['accur']

    #print('startepoch:%d accuracy:%f' %(start_epoch,accu))

    for epoch in range(start_epoch, epoch_total):
        net.train()
        data_iter = iter(dataloader)

        #print('\nEpoch: %d' % epoch)
        train_loss = 0
        acc_total = 0
        for step in range(len(dataloader) - 1):
            #print('----epoch:%d------step:%d------' %(epoch,step))
            data = next(data_iter)

            randomH = np.random.randint(0, H - h - 1)
            randomW = np.random.randint(0, W - w - 1)
            imageL = data['imL'][:, :, randomH:(randomH + h),
                                 randomW:(randomW + w)]
            imageR = data['imR'][:, :, randomH:(randomH + h),
                                 randomW:(randomW + w)]
            disL = data['dispL'][:, :, randomH:(randomH + h),
                                 randomW:(randomW + w)]
            imL.resize_(imageL.size()).copy_(imageL)
            imR.resize_(imageR.size()).copy_(imageR)
            dispL.resize_(disL.size()).copy_(disL)
            #normalize
            # imgL=normalizeRGB(imL)
            # imgR=normalizeRGB(imR)

            net.zero_grad()
            optimizer.zero_grad()

            x = net(imL, imR)
            # print(x.shape)
            # print(loss_mul.shape)
            # print(net)
            result = torch.sum(x.mul(loss_mul), 1)
            result = result[:, None, :]
            # print(result.shape)
            #print_gpu_info()
            tt = loss_fn(result, dispL)
            #print_gpu_info()
            train_loss += tt.item()
            # tt = loss(x, loss_mul, dispL)
            tt.backward()
            optimizer.step()

            result = result.view(batch, 1, h, w)
            diff = torch.abs(result.cpu() - dispL.cpu())
            accuracy = torch.sum(diff < 3).item() / float(h * w * batch)
            acc_total += accuracy

            if step % show_n == (show_n - 1):
                writer.add_scalar('Loss/train_loss', train_loss / show_n,
                                  n_iter)
                #imL_ = unnormalize(imL[0])
                #disp_NET_ = result[0]
                #writer.add_image('Image/left', imL_, n_iter)
                #writer.add_image('Image/disparity', disp_NET_, n_iter)
                writer.flush()  # keep the writer open across iterations

                n_iter += 1
                print('[%d, %5d, %5d] train_loss %.5f' %
                      (epoch + 1, step + 1, len(dataloader),
                       train_loss / show_n))
                train_loss = 0.0

            if show:
                imL_ = unnormalize(imL[0]).permute(1, 2,
                                                   0).cpu().detach().numpy()
                imR_ = unnormalize(imR[0]).permute(1, 2,
                                                   0).cpu().detach().numpy()
                disp_TRUE_ = disL.cpu().detach().numpy()[0][0]
                disp_NET_ = result.cpu().detach().numpy()[0][0]
                plt.figure(figsize=(16, 8))
                plt.subplot(2, 2, 1)
                plt.imshow(imL_[..., ::-1])
                plt.subplot(2, 2, 2)
                plt.imshow(imR_[..., ::-1])
                plt.subplot(2, 2, 3)
                plt.imshow(disp_TRUE_, cmap='rainbow', vmin=0, vmax=maxdisp)
                plt.colorbar()
                plt.subplot(2, 2, 4)
                plt.imshow(disp_NET_, cmap='rainbow', vmin=0, vmax=maxdisp)
                plt.colorbar()
                plt.show()

            #print('====accuracy for the result less than 3 pixels===:%f' %accuracy)
            #print('====average accuracy for the result less than 3 pixels===:%f' % (acc_total/(step+1)))

            # save
            if step % 1000 == 0:
                state = {
                    'net': net.state_dict(),
                    'step': step,
                    'loss_list': loss_list,
                    'epoch': epoch,
                    'accur': acc_total
                }
                torch.save(state, 'checkpoint/ckpt.t7')
    writer.close()
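
Both training examples compute the 3-pixel accuracy inline with torch.sum(diff < 3) / float(h * w * batch). A small sketch of the same metric as a standalone helper; three_px_accuracy is a hypothetical name, not part of the original repository:

import torch

def three_px_accuracy(pred_disp, gt_disp, threshold=3.0):
    # Fraction of pixels whose absolute disparity error is below `threshold`.
    diff = torch.abs(pred_disp.detach().cpu() - gt_disp.detach().cpu())
    return (diff < threshold).float().mean().item()
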
Example #3
def train(epoch_total, loadstate):

    loss_mul_list = []
    for d in range(maxdisp):
        loss_mul_temp = Variable(torch.Tensor(np.ones([batch, 1, h, w]) *
                                              d)).cuda()
        loss_mul_list.append(loss_mul_temp)
    loss_mul = torch.cat(loss_mul_list, 1)

    optimizer = optim.RMSprop(net.parameters(), lr=0.001, alpha=0.9)
    dataset = sceneDisp('', 'train', tsfm)
    loss_fn = nn.L1Loss()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch,
                                             shuffle=True,
                                             num_workers=1)

    imL = Variable(torch.FloatTensor(1).cuda())
    imR = Variable(torch.FloatTensor(1).cuda())
    dispL = Variable(torch.FloatTensor(1).cuda())

    loss_list = []
    print(len(dataloader))
    start_epoch = 0
    accu = 0.0  # default so the print below works when no checkpoint is loaded
    if loadstate:
        checkpoint = torch.load('./checkpoint/ckpt.t7')
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        accu = checkpoint['accur']
    print('startepoch:%d accuracy:%f' % (start_epoch, accu))
    for epoch in range(start_epoch, epoch_total):
        net.train()
        data_iter = iter(dataloader)

        print('\nEpoch: %d' % epoch)
        train_loss = 0
        acc_total = 0
        for step in range(len(dataloader) - 1):
            print('----epoch:%d------step:%d------' % (epoch, step))
            data = next(data_iter)

            randomH = np.random.randint(0, 160)
            randomW = np.random.randint(0, 400)
            imageL = data['imL'][:, :, randomH:(randomH + h),
                                 randomW:(randomW + w)]
            imageR = data['imR'][:, :, randomH:(randomH + h),
                                 randomW:(randomW + w)]
            disL = data['dispL'][:, :, randomH:(randomH + h),
                                 randomW:(randomW + w)]
            imL.data.resize_(imageL.size()).copy_(imageL)
            imR.data.resize_(imageR.size()).copy_(imageR)
            dispL.data.resize_(disL.size()).copy_(disL)
            #normalize
            # imgL=normalizeRGB(imL)
            # imgR=normalizeRGB(imR)

            net.zero_grad()
            optimizer.zero_grad()

            x = net(imL, imR)
            # print(x.shape)
            # print(loss_mul.shape)
            # print(net)
            result = torch.sum(x.mul(loss_mul), 1)
            result = result[:, None, :, :]  # [batch, 1, h, w] to match dispL
            # print(result.shape)
            tt = loss_fn(result, dispL)
            train_loss += tt.item()
            # tt = loss(x, loss_mul, dispL)
            tt.backward()
            optimizer.step()
            print('=======loss value for every step=======:%f' % tt.item())
            print('=======average loss value for every step=======:%f' %
                  (train_loss / (step + 1)))
            result = result.view(batch, 1, h, w)
            diff = torch.abs(result.data.cpu() - dispL.data.cpu())
            print(diff.shape)
            accuracy = torch.sum(diff < 3).item() / float(h * w * batch)
            acc_total += accuracy
            print('====accuracy for the result less than 3 pixels===:%f' %
                  accuracy)
            print(
                '====average accuracy for the result less than 3 pixels===:%f'
                % (acc_total / (step + 1)))

            # save
            if step % 100 == 0:
                loss_list.append(train_loss / (step + 1))
            if (step > 1 and step % 200 == 0) or step == len(dataloader) - 2:
                print('=======>saving model......')
                state = {
                    'net': net.state_dict(),
                    'step': step,
                    'loss_list': loss_list,
                    'epoch': epoch,
                    'accur': acc_total
                }
                torch.save(state, 'checkpoint/ckpt.t7')

                im = result[0, :, :, :].data.cpu().numpy().astype('uint8')
                im = np.transpose(im, (1, 2, 0))
                cv2.imwrite('train_result.png', im,
                            [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
                gt = np.transpose(dispL[0, :, :, :].data.cpu().numpy(),
                                  (1, 2, 0))
                cv2.imwrite('train_gt.png', gt,
                            [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

    with open('loss.txt', 'w') as fp:
        for loss_value in loss_list:
            fp.write(str(loss_value))
            fp.write('\n')
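
All three examples repeat the same checkpoint load/save code around the keys 'net', 'step', 'loss_list', 'epoch' and 'accur'. A hedged sketch factoring that logic into two helpers, assuming those keys stay exactly as above; save_checkpoint and load_checkpoint are hypothetical names, not functions from the original code:

import os
import torch

def save_checkpoint(net, epoch, step, loss_list, accuracy,
                    path='checkpoint/ckpt.t7'):
    # Write the same state dict the training loops above save.
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    torch.save({
        'net': net.state_dict(),
        'step': step,
        'loss_list': loss_list,
        'epoch': epoch,
        'accur': accuracy,
    }, path)

def load_checkpoint(net, path='./checkpoint/ckpt.t7'):
    # Restore weights and return (start_epoch, accuracy), matching the
    # loadstate branches above.
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['net'])
    return checkpoint['epoch'], checkpoint['accur']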