Ejemplo n.º 1
0
def spiral_tramsform(transform_fp, template_fp, ds_factors, seq_length,
                     dilation):
    """Load or generate mesh transforms, then build spiral indices for them.

    If ``transform_fp`` does not exist, the template mesh is downsampled with
    ``mesh_sampling.generate_transform_matrices`` and the resulting
    vertices / faces / adjacency / down / up matrices are pickled to
    ``transform_fp``; otherwise that pickle is loaded. The cache is checked
    by path only: ``template_fp`` and ``ds_factors`` are ignored when the file
    already exists, so delete it after changing either.

    Args:
        transform_fp: Path of the pickle cache to read or create.
        template_fp: Template mesh file, used only when generating.
        ds_factors: Downsampling factors, used only when generating.
        seq_length: Per-level spiral lengths, indexed by mesh level.
        dilation: Per-level spiral dilations, indexed by mesh level.

    Returns:
        ``(spiral_indices_list, down_transform_list, up_transform_list, tmp)``
        where the spiral list has one entry per element of ``tmp['face']``
        except the last, the transform lists hold ``utils.to_sparse``
        conversions of the cached matrices, and ``tmp`` is the cached dict.
        Nothing is moved to a device here; callers must do that if needed.
    """
    import os

    if not osp.exists(transform_fp):
        print('Generating transform matrices...')
        mesh = Mesh(filename=template_fp)
        _, A, D, U, F, V = mesh_sampling.generate_transform_matrices(
            mesh, ds_factors)
        tmp = {
            'vertices': V,
            'face': F,
            'adj': A,
            'down_transform': D,
            'up_transform': U
        }

        # Dump to a side file and atomically rename it into place. Writing
        # transform_fp directly means an interrupted dump leaves a truncated
        # pickle that every later call would find (osp.exists) and fail on.
        partial_fp = transform_fp + '.partial'
        try:
            with open(partial_fp, 'wb') as fp:
                pickle.dump(tmp, fp)
            os.replace(partial_fp, transform_fp)
        finally:
            if osp.exists(partial_fp):
                os.remove(partial_fp)
        print('Done!')
        print('Transform matrices are saved in \'{}\''.format(transform_fp))
    else:
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # caches. encoding='latin1' allows reading pickles written by Python 2.
        with open(transform_fp, 'rb') as f:
            tmp = pickle.load(f, encoding='latin1')

    # The final entry of tmp['face'] gets no spiral indices.
    spiral_indices_list = [
        utils.preprocess_spiral(tmp['face'][idx], seq_length[idx],
                                tmp['vertices'][idx], dilation[idx])
        for idx in range(len(tmp['face']) - 1)
    ]

    down_transform_list = [
        utils.to_sparse(down_transform)
        for down_transform in tmp['down_transform']
    ]
    up_transform_list = [
        utils.to_sparse(up_transform)
        for up_transform in tmp['up_transform']
    ]

    return spiral_indices_list, down_transform_list, up_transform_list, tmp
Ejemplo n.º 2
0
import os  # for os.replace / os.remove in the cache write below

# Build the train/test datasets and their loaders.
meshdata = MeshData(args.data_fp,
                    template_fp,
                    split=args.split,
                    test_exp=args.test_exp)
train_loader = DataLoader(meshdata.train_dataset,
                          batch_size=args.batch_size,
                          shuffle=True)
test_loader = DataLoader(meshdata.test_dataset, batch_size=args.batch_size)

# generate/load transform matrices
# NOTE: the cache is checked by path only; delete transform.pkl after
# changing ds_factors or the template mesh.
transform_fp = osp.join(args.data_fp, 'transform.pkl')
if not osp.exists(transform_fp):
    print('Generating transform matrices...')
    mesh = Mesh(filename=template_fp)
    ds_factors = [4, 4, 4, 4]
    _, A, D, U, F = mesh_sampling.generate_transform_matrices(mesh, ds_factors)
    tmp = {'face': F, 'adj': A, 'down_transform': D, 'up_transform': U}

    # Write to a side file and rename it into place: dumping straight into
    # transform_fp would leave a truncated pickle after a crash, which the
    # `else` branch would then try (and fail) to load on every later run.
    partial_fp = transform_fp + '.partial'
    try:
        with open(partial_fp, 'wb') as fp:
            pickle.dump(tmp, fp)
        os.replace(partial_fp, transform_fp)
    finally:
        if osp.exists(partial_fp):
            os.remove(partial_fp)
    print('Done!')
    print('Transform matrices are saved in \'{}\''.format(transform_fp))
else:
    # NOTE: unpickling can execute arbitrary code; only load trusted caches.
    # encoding='latin1' allows reading pickles written by Python 2.
    with open(transform_fp, 'rb') as f:
        tmp = pickle.load(f, encoding='latin1')

# Convert each cached adjacency / down-sampling matrix and move it to `device`.
edge_index_list = [utils.to_edge_index(adj).to(device) for adj in tmp['adj']]
down_transform_list = [
    utils.to_sparse(down_transform).to(device)
    for down_transform in tmp['down_transform']
]
Ejemplo n.º 3
0
def _mesh_vertices(mesh, meshpackage):
    """Return the vertex array of ``mesh`` using ``meshpackage``'s accessor."""
    # mpi-mesh ``Mesh`` objects expose vertices as ``.v``; trimesh as ``.vertices``.
    if meshpackage == 'mpi-mesh':
        return mesh.v
    return mesh.vertices


def _pad_with_identity(mat):
    """Embed an ``(r, c)`` matrix in a dense ``(1, r + 1, c + 1)`` array.

    The extra last row/column is zero except for a 1 in the bottom-right
    corner (presumably so a dummy vertex appended after the real ones is
    carried through unchanged -- TODO confirm against the model code).
    """
    rows, cols = mat.shape
    padded = np.zeros((1, rows + 1, cols + 1))
    padded[0, :-1, :-1] = mat.todense()
    padded[0, -1, -1] = 1
    return padded


def _dump_pickle_atomic(obj, path):
    """Pickle ``obj`` to ``path`` via a side file and an atomic rename."""
    import os

    directory = osp.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    partial_path = path + '.partial'
    try:
        with open(partial_path, 'wb') as fp:
            pickle.dump(obj, fp)
        os.replace(partial_path, path)
    finally:
        if osp.exists(partial_path):
            os.remove(partial_path)


def process_mesh(opt, body_part, device):
    """Load a body part's data and build its spirals and sampling matrices.

    Transform matrices are read from ``<MESH_DATA_DIR>/<body_part>/<ds_factors>.pkl``
    or, if that file is missing, generated with ``generate_transform_matrices``
    (COMA downsampling, mpi-mesh only) and cached there.

    Args:
        opt: Options object; reads ``processed_verts_dir`` and
            ``dataset.normalization`` / ``meshpackage`` / ``downsample_method``.
        body_part: Key into the per-body-part tables in ``constants``.
        device: Target passed to ``.to()`` for every returned tensor.

    Returns:
        ``(shapedata, tspirals, spiral_sizes, vnum, tD, tU)``: the ``ShapeData``,
        spiral index tensors (long), spiral sizes from ``generate_spirals``,
        per-level vertex counts, and padded down/up matrices as float tensors
        of shape ``(1, rows + 1, cols + 1)``.

    Raises:
        ValueError: if ``shapedata.meshpackage`` is not mpi-mesh or trimesh.
        NotImplementedError: if the cache is missing and generation is not
            possible (trimesh package or a non-COMA downsample method).
    """
    ref_mesh = constants.REF_MESH[body_part]
    ds_factors = constants.DS_FACTORS[body_part]
    # Copy: one entry per downsampled level is appended below, and appending
    # to the shared constant would leak state into every later call.
    reference_points = list(constants.REFERENCE_POINTS[body_part])
    step_sizes = constants.STEP_SIZES[body_part]
    dilation = constants.DILATION_VAL[body_part]
    shapedata = ShapeData(train_file=osp.join(opt.processed_verts_dir,
                                              body_part,
                                              'preprocessed/train_verts.npy'),
                          val_file=osp.join(opt.processed_verts_dir, body_part,
                                            'preprocessed/val_verts.npy'),
                          test_file=osp.join(opt.processed_verts_dir,
                                             body_part,
                                             'preprocessed/test_verts.npy'),
                          reference_mesh_file=ref_mesh,
                          normalization=opt.dataset.normalization,
                          meshpackage=opt.dataset.meshpackage,
                          load_flag=True)
    meshpackage = shapedata.meshpackage
    # Every branch below handles only these two; anything else used to fail
    # later with an UnboundLocalError.
    if meshpackage not in ('mpi-mesh', 'trimesh'):
        raise ValueError(
            'Unsupported meshpackage {!r}; expected mpi-mesh or trimesh'.format(
                meshpackage))

    pkl_path = osp.join(constants.MESH_DATA_DIR, body_part,
                        '{}.pkl'.format(ds_factors))
    print(pkl_path)
    if not osp.exists(pkl_path):
        if meshpackage == 'trimesh':
            raise NotImplementedError('Rerun with mpi-mesh as meshpackage')

        print("Generating Transform Matrices ..")
        if opt.dataset.downsample_method != 'COMA_downsample':
            raise NotImplementedError('Rerun with COMA_downsample')
        M, A, D, U, F = generate_transform_matrices(
            shapedata.reference_mesh, ds_factors)
        # Build the payload before touching the file so a failure here cannot
        # leave an empty cache behind that later runs would try to load.
        M_verts_faces = [(m.v, m.f) for m in M]
        _dump_pickle_atomic(
            {
                'M_verts_faces': M_verts_faces,
                'A': A,
                'D': D,
                'U': U,
                'F': F
            }, pkl_path)
    else:
        print("Loading Transform Matrices ..")
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # caches. encoding='latin1' allows reading pickles written by Python 2.
        with open(pkl_path, 'rb') as fp:
            downsampling_matrices = pickle.load(fp, encoding='latin1')

        M_verts_faces = downsampling_matrices['M_verts_faces']
        if meshpackage == 'mpi-mesh':
            M = [Mesh(v=vf[0], f=vf[1]) for vf in M_verts_faces]
        else:
            M = [
                trimesh.base.Trimesh(vertices=vf[0], faces=vf[1],
                                     process=False)
                for vf in M_verts_faces
            ]
        A = downsampling_matrices['A']
        D = downsampling_matrices['D']
        U = downsampling_matrices['U']
        F = downsampling_matrices['F']

    print("Calculating reference points for downsampled versions ..")
    # For each downsampled level, pick the vertices nearest to the full-mesh
    # reference vertices.
    base_points = _mesh_vertices(M[0], meshpackage)[reference_points[0]]
    for i in range(len(ds_factors)):
        dist = euclidean_distances(_mesh_vertices(M[i + 1], meshpackage),
                                   base_points)
        reference_points.append(np.argmin(dist, axis=0).tolist())

    vnum = [_mesh_vertices(m, meshpackage).shape[0] for m in M]
    Adj, Trigs = get_adj_trigs(A,
                               F,
                               shapedata.reference_mesh,
                               meshpackage=meshpackage)

    print("Generating Spirals ..")
    spirals_np, spiral_sizes, spirals = generate_spirals(
        step_sizes,
        M,
        Adj,
        Trigs,
        reference_points=reference_points,
        dilation=dilation,
        random=False,
        meshpackage=meshpackage,
        counter_clockwise=True)
    bD = [_pad_with_identity(d) for d in D]
    bU = [_pad_with_identity(u) for u in U]

    tspirals = [torch.from_numpy(s).long().to(device) for s in spirals_np]
    tD = [torch.from_numpy(s).float().to(device) for s in bD]
    tU = [torch.from_numpy(s).float().to(device) for s in bU]
    return shapedata, tspirals, spiral_sizes, vnum, tD, tU