import glob
import os

import cv2
import numpy

import py_normals

# Quick check: back-project one depth frame to a point cloud and start
# writing it out as an ASCII PLY file ("plytest.ply").
#
# NOTE(review): the PLY header written below is incomplete -- it stops after
# "property uchar diffuse_red" with no "end_header" line and no vertex
# records, and ``points`` is never written to the file. Confirm the intended
# remaining header/body before relying on plytest.ply.
DEPTH_IMAGE_PATH = "/Users/bdol/code/rgbd-photometric/rgbd-util/ein_0_med_frames/depth/depth_00000.png"

# -1 (cv2.IMREAD_UNCHANGED) loads the file without converting its bit depth.
im = cv2.imread(DEPTH_IMAGE_PATH, -1)
if im is None:
    # cv2.imread reports a missing/unreadable file by returning None.
    raise IOError("could not read depth image " + DEPTH_IMAGE_PATH)
pcloud = py_normals.depth_to_world(im)
nRows = pcloud.shape[0]
nCols = pcloud.shape[1]

# One [x, y, -z, valid] record per pixel, in row-major (row, then column)
# order; valid is 1 when the depth lies strictly between 0 and 360, else 0.
points = [
    [x, y, -z, 1 if 0 < z < 360 else 0]
    for x, y, z in pcloud[:, :, :3].reshape(-1, 3)
]

header_lines = [
    "ply",
    "format ascii 1.0",
    "obj_info is_mesh 0",
    "obj_info num_cols " + str(nCols),
    "obj_info num_rows " + str(nRows),
    "element vertex " + str(nCols * nRows),
    "property float x",
    "property float y",
    "property float z",
    "property uchar diffuse_red",
]
# The context manager guarantees the header is flushed and the file closed.
with open("plytest.ply", "w") as ply:
    ply.write("".join(line + "\n" for line in header_lines))
def main(video_dir=r'C:\Projects\GitHub\rgbd-photometric\rgbd-util\daniel'):
    """Reconstruct a surface from an RGB-D sequence and export two OBJ meshes.

    Steps:

    1. Load the first depth frame and zero out readings above 560.
    2. Crop the frame to a fixed ROI and back-project it to a point cloud.
    3. Estimate reference normals from that cloud.
    4. Estimate photometric normals from the RGB frames (``make_M`` +
       ``solve_for_L_and_N``).
    5. Align the photometric normals to the reference normals with the
       linear map returned by ``solve_for_A``.
    6. Integrate the aligned normals into a depth map.
    7. Fit a bas-relief correction ``t`` against the sensor depth.
    8. Write two meshes:

       * ``test.obj``: vertices from the integrated depth with ``t``
         applied.
       * ``depth.obj``: the same x/y, with the raw sensor depth as z.

    Both meshes share the same triangle connectivity.

    Args:
        video_dir: Directory containing ``depth/`` and ``rgb/`` subfolders of
            ``*.png`` frames. Defaults to the original hard-coded path.

    Raises:
        IOError: If no depth frame is found or the first one cannot be read.
    """
    depth_dir = os.path.join(video_dir, 'depth')
    rgb_dir = os.path.join(video_dir, 'rgb')

    # glob() returns paths in file-system order. Sort them so that the
    # "first" depth frame and the RGB frame order passed to make_M are
    # deterministic.
    depth_image_filenames = sorted(glob.glob(os.path.join(depth_dir, '*.png')))[0:10]
    rgb_image_filenames = sorted(glob.glob(os.path.join(rgb_dir, '*.png')))
    if not depth_image_filenames:
        raise IOError('no depth frames (*.png) found in ' + depth_dir)

    # -1 (cv2.IMREAD_UNCHANGED) loads the file without converting its bit depth.
    im = cv2.imread(depth_image_filenames[0], -1)
    if im is None:  # cv2.imread returns None instead of raising on failure
        raise IOError('could not read depth image ' + depth_image_filenames[0])
    # Zero out readings above 560. Units depend on the sensor export
    # (TODO confirm).
    im[im > 560] = 0

    # ROI in the full frame as [x0, y0, width, height]. x0/y0 are the column
    # and row offsets added back when back-projecting vertices below.
    width = 223
    height = 307
    roi = [185, 45, width, height]
    im = apply_roi(im, roi)

    # Reference geometry from the depth sensor: point cloud + per-pixel normals.
    pcloud = py_normals.depth_to_world(im)
    ref_normals, valid = py_normals.crossprod_normals(pcloud)
    flat_ref_normals = flatten_normals(ref_normals)

    # Photometric normals from the RGB frames. The 3 is presumably the
    # factorisation rank; see solve_for_L_and_N.
    M = make_M(rgb_image_filenames, roi)
    L, N = solve_for_L_and_N(M, 3)

    # Fit the alignment A only on pixels crossprod_normals marked valid.
    # Cast to bool: a 0/1 integer mask would otherwise be interpreted by
    # NumPy as column *indices* rather than as a boolean selection.
    flat_valid = valid.reshape(-1).astype(bool)
    A = solve_for_A(flat_ref_normals[:, flat_valid], N[:, flat_valid])
    N = numpy.dot(A, N)
    # Unit-normalise each column (one normal per ROI pixel). The epsilon
    # avoids division by zero for all-zero normals.
    N /= 1e-5 + numpy.linalg.norm(N, axis=0)

    # Column i*width + j of N belongs to pixel (i, j), i.e. row-major order.
    normals_image = N[:, :height * width].T.reshape(height, width, 3).astype(numpy.float64)

    depth = integrate_normals(normals_image, pcloud[:, :, 2] > 0, pcloud[:, :, 2], 0.1)
    t = solve_bas_relief(pcloud[:, :, 2], depth, pcloud[:, :, 2] > 0)

    # Back-projection for the full 640x480 frame (pinhole model:
    # x = (u - cx) * z / focal), with the principal point at the image centre.
    cx = 640 / 2
    cy = 480 / 2
    focal = 535

    # Only pixels with a non-degenerate normal become vertices.
    has_normal = numpy.linalg.norm(normals_image, axis=2) > 0.5
    # 0-based OBJ vertex index of each pixel, or -1 where there is no vertex.
    indices = -1 * numpy.ones((height, width), numpy.int32)

    with open('test.obj', 'w') as out, open('depth.obj', 'w') as out2:
        count = 0
        for i in range(height):
            for j in range(width):
                if not has_normal[i, j]:
                    continue
                x = (j + roi[0] - cx) * depth[i, j] / focal
                y = (i + roi[1] - cy) * depth[i, j] / focal
                # test.obj: integrated depth after the bas-relief correction t.
                z = t[0] * j + t[1] * i + t[2] * depth[i, j] + t[3]
                out.write('v {0} {1} {2}\n'.format(x, -y, -z))
                # depth.obj: same x/y, raw sensor depth as z.
                out2.write('v {0} {1} {2}\n'.format(x, -y, -pcloud[i, j, 2]))
                indices[i, j] = count
                count += 1

        # Split each 2x2 pixel quad anchored at a vertex (i, j) into up to two
        # triangles. OBJ face indices are 1-based, and both meshes get exactly
        # the same faces.
        for i in range(height - 1):
            for j in range(width - 1):
                if not has_normal[i, j]:
                    continue
                idx1 = indices[i, j]
                idx2 = indices[i + 1, j]
                idx3 = indices[i, j + 1]
                idx4 = indices[i + 1, j + 1]
                if idx4 < 0:
                    continue
                triangles = []
                if idx2 >= 0:
                    triangles.append((idx4, idx1, idx2))
                if idx3 >= 0:
                    triangles.append((idx4, idx3, idx1))
                for a, b, c in triangles:
                    face = 'f {0} {1} {2}\n'.format(a + 1, b + 1, c + 1)
                    out.write(face)
                    out2.write(face)