def spiral_tramsform(transform_fp, template_fp, ds_factors, seq_length, dilation):
    """Load (or build and cache) mesh transform matrices and derive spiral indices.

    When ``transform_fp`` does not exist yet, the template mesh at
    ``template_fp`` is downsampled with ``ds_factors`` through
    ``mesh_sampling.generate_transform_matrices``. The resulting vertices,
    faces, adjacency, down- and up-sampling matrices are pickled to
    ``transform_fp``. Otherwise the existing pickle is loaded instead.

    Args:
        transform_fp: Path of the pickled transform cache.
        template_fp: Path of the template mesh (read only when generating).
        ds_factors: Downsampling factors (used only when generating).
        seq_length: Spiral sequence length per level, indexed by level.
        dilation: Spiral dilation per level, indexed by level.

    Returns:
        ``(spiral_indices_list, down_transform_list, up_transform_list, tmp)``,
        where ``tmp`` is the loaded or freshly generated transform dict.
        Spiral indices are produced for every level of ``tmp['face']``
        except the last one.
    """
    if osp.exists(transform_fp):
        # NOTE: pickle.load can execute arbitrary code; only load trusted caches.
        with open(transform_fp, 'rb') as f:
            tmp = pickle.load(f, encoding='latin1')
    else:
        print('Generating transform matrices...')
        template_mesh = Mesh(filename=template_fp)
        _, adj, down, up, faces, verts = mesh_sampling.generate_transform_matrices(
            template_mesh, ds_factors)
        # Key order is kept stable so the pickled cache is laid out identically.
        tmp = {
            'vertices': verts,
            'face': faces,
            'adj': adj,
            'down_transform': down,
            'up_transform': up,
        }
        with open(transform_fp, 'wb') as fp:
            pickle.dump(tmp, fp)
        print('Done!')
        print('Transform matrices are saved in \'{}\''.format(transform_fp))

    spiral_indices_list = []
    for level, face in enumerate(tmp['face'][:-1]):
        spiral_indices_list.append(
            utils.preprocess_spiral(face, seq_length[level],
                                    tmp['vertices'][level], dilation[level]))

    down_transform_list = [utils.to_sparse(mat) for mat in tmp['down_transform']]
    up_transform_list = [utils.to_sparse(mat) for mat in tmp['up_transform']]

    return spiral_indices_list, down_transform_list, up_transform_list, tmp
# Script section: build the datasets and data loaders, then generate (on the
# first run) or load the cached mesh transform matrices and move the derived
# graph structures to `device`.
# NOTE(review): relies on `args`, `template_fp`, `device`, `MeshData`,
# `DataLoader`, `Mesh`, `mesh_sampling`, `utils`, `osp` and `pickle` being
# defined earlier in the file (outside this view).
meshdata = MeshData(args.data_fp,
                    template_fp,
                    split=args.split,
                    test_exp=args.test_exp)
# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(meshdata.train_dataset,
                          batch_size=args.batch_size,
                          shuffle=True)
test_loader = DataLoader(meshdata.test_dataset, batch_size=args.batch_size)

# generate/load transform matrices
# The matrices are cached at <data_fp>/transform.pkl; they are regenerated
# only when that file is missing.
transform_fp = osp.join(args.data_fp, 'transform.pkl')
if not osp.exists(transform_fp):
    print('Generating transform matrices...')
    mesh = Mesh(filename=template_fp)
    # Four downsampling factors of 4, passed straight to mesh_sampling.
    ds_factors = [4, 4, 4, 4]
    # NOTE(review): this unpacks 5 return values, whereas spiral_tramsform in
    # this file unpacks 6 (including vertices), and the dict cached here has no
    # 'vertices' key. Confirm which mesh_sampling version this script targets.
    _, A, D, U, F = mesh_sampling.generate_transform_matrices(mesh, ds_factors)
    tmp = {'face': F, 'adj': A, 'down_transform': D, 'up_transform': U}

    with open(transform_fp, 'wb') as fp:
        pickle.dump(tmp, fp)
    print('Done!')
    print('Transform matrices are saved in \'{}\''.format(transform_fp))
else:
    # encoding='latin1' allows reading pickles written under Python 2.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(transform_fp, 'rb') as f:
        tmp = pickle.load(f, encoding='latin1')

# One edge-index tensor per entry of tmp['adj'], placed on `device`.
edge_index_list = [utils.to_edge_index(adj).to(device) for adj in tmp['adj']]

# Down-sampling matrices converted via utils.to_sparse and moved to `device`.
# The up-sampling matrices are not converted in this visible fragment;
# presumably that happens right after it.
down_transform_list = [
    utils.to_sparse(down_transform).to(device)
    for down_transform in tmp['down_transform']
]
def _mesh_vertices(mesh, meshpackage):
    """Return the vertex array of ``mesh`` ('mpi-mesh' -> ``.v``, 'trimesh' -> ``.vertices``)."""
    if meshpackage == 'mpi-mesh':
        return mesh.v
    return mesh.vertices


def _generate_transforms(shapedata, opt, ds_factors, pkl_path):
    """Generate the COMA mesh hierarchy/transform matrices and cache them at ``pkl_path``.

    Returns ``(M, A, D, U, F)`` as produced by ``generate_transform_matrices``.

    Raises:
        NotImplementedError: for the 'trimesh' backend, or when
            ``opt.dataset.downsample_method`` is not 'COMA_downsample'.
    """
    if shapedata.meshpackage == 'trimesh':
        raise NotImplementedError('Rerun with mpi-mesh as meshpackage')
    print("Generating Transform Matrices ..")
    if opt.dataset.downsample_method != 'COMA_downsample':
        raise NotImplementedError('Rerun with COMA_downsample')

    M, A, D, U, F = generate_transform_matrices(shapedata.reference_mesh,
                                                ds_factors)
    # Collect the payload *before* opening the file: opening with 'wb' creates
    # the file, so an error here would otherwise leave an empty cache that the
    # next run would try (and fail) to load.
    M_verts_faces = [(mesh.v, mesh.f) for mesh in M]
    payload = {'M_verts_faces': M_verts_faces, 'A': A, 'D': D, 'U': U, 'F': F}
    with open(pkl_path, 'wb') as fp:
        pickle.dump(payload, fp)
    return M, A, D, U, F


def _load_transforms(pkl_path, meshpackage):
    """Load cached transform matrices from ``pkl_path`` and rebuild the meshes.

    Returns ``(M, A, D, U, F)`` where ``M`` holds meshes of the requested backend.
    """
    print("Loading Transform Matrices ..")
    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # produced by this pipeline. encoding='latin1' is needed on Python 3 to
    # read pickles written under Python 2.
    with open(pkl_path, 'rb') as fp:
        downsampling_matrices = pickle.load(fp, encoding='latin1')

    M_verts_faces = downsampling_matrices['M_verts_faces']
    if meshpackage == 'mpi-mesh':
        M = [Mesh(v=vf[0], f=vf[1]) for vf in M_verts_faces]
    else:  # 'trimesh' (validated by the caller)
        M = [
            trimesh.base.Trimesh(vertices=vf[0], faces=vf[1], process=False)
            for vf in M_verts_faces
        ]
    return (M, downsampling_matrices['A'], downsampling_matrices['D'],
            downsampling_matrices['U'], downsampling_matrices['F'])


def _pad_with_dummy_vertex(mat):
    """Densify sparse ``mat`` into a ``(1, rows + 1, cols + 1)`` array.

    The extra last row/column is zero except for a 1 in the bottom-right
    corner. NOTE(review): presumably this maps a dummy padding vertex onto
    itself; confirm against the model code.
    """
    rows, cols = mat.shape
    padded = np.zeros((1, rows + 1, cols + 1))
    padded[0, :-1, :-1] = mat.todense()
    padded[0, -1, -1] = 1
    return padded


def process_mesh(opt, body_part, device):
    """Prepare shape data, spiral indices and padded sampling matrices for a body part.

    Reads the per-body-part configuration from ``constants`` (reference mesh,
    downsampling factors, reference points, spiral step sizes, dilation).
    Builds a ``ShapeData`` from the preprocessed train/val/test vertex files.
    It then either loads the mesh hierarchy and transform matrices cached at
    ``<MESH_DATA_DIR>/<body_part>/<str(ds_factors)>.pkl``, or generates and
    caches them.

    Args:
        opt: Options object; reads ``processed_verts_dir`` and
            ``dataset.normalization`` / ``dataset.meshpackage`` /
            ``dataset.downsample_method``.
        body_part: Key into the per-body-part tables in ``constants``.
        device: Torch device the returned tensors are moved to.

    Returns:
        ``(shapedata, tspirals, spiral_sizes, vnum, tD, tU)``:

        - ``tspirals``: the spiral index arrays as int64 tensors.
        - ``spiral_sizes``: passed through from ``generate_spirals``.
        - ``vnum``: the vertex count of each mesh level.
        - ``tD`` / ``tU``: the down-/up-sampling matrices as float32 tensors
          of shape ``(1, rows + 1, cols + 1)``, each with an extra padding
          row/column.

    Raises:
        ValueError: if ``shapedata.meshpackage`` is neither 'mpi-mesh' nor
            'trimesh'.
        NotImplementedError: if the cache is missing and either the backend is
            'trimesh' or ``opt.dataset.downsample_method`` is not
            'COMA_downsample'.
    """
    ref_mesh = constants.REF_MESH[body_part]
    ds_factors = constants.DS_FACTORS[body_part]
    # Copy: one entry per downsampled level is appended below. Appending to the
    # shared constant would make it grow on every call and feed stale extra
    # reference points to generate_spirals.
    reference_points = list(constants.REFERENCE_POINTS[body_part])
    step_sizes = constants.STEP_SIZES[body_part]
    dilation = constants.DILATION_VAL[body_part]

    shapedata = ShapeData(
        train_file=osp.join(opt.processed_verts_dir, body_part,
                            'preprocessed/train_verts.npy'),
        val_file=osp.join(opt.processed_verts_dir, body_part,
                          'preprocessed/val_verts.npy'),
        test_file=osp.join(opt.processed_verts_dir, body_part,
                           'preprocessed/test_verts.npy'),
        reference_mesh_file=ref_mesh,
        normalization=opt.dataset.normalization,
        meshpackage=opt.dataset.meshpackage,
        load_flag=True)

    meshpackage = shapedata.meshpackage
    # Fail fast: an unknown backend used to surface later as UnboundLocalError.
    if meshpackage not in ('mpi-mesh', 'trimesh'):
        raise ValueError(
            "Unsupported meshpackage '{}'; expected 'mpi-mesh' or 'trimesh'"
            .format(meshpackage))

    pkl_path = osp.join(constants.MESH_DATA_DIR, body_part,
                        '{}.pkl'.format(ds_factors))
    print(pkl_path)
    if osp.exists(pkl_path):
        M, A, D, U, F = _load_transforms(pkl_path, meshpackage)
    else:
        M, A, D, U, F = _generate_transforms(shapedata, opt, ds_factors,
                                             pkl_path)

    print("Calculating reference points for downsampled versions ..")
    # For each downsampled level, pick the vertex nearest to each of the
    # full-resolution reference vertices.
    ref_vertices = _mesh_vertices(M[0], meshpackage)[reference_points[0]]
    for level in range(1, len(ds_factors) + 1):
        dist = euclidean_distances(_mesh_vertices(M[level], meshpackage),
                                   ref_vertices)
        reference_points.append(np.argmin(dist, axis=0).tolist())

    vnum = [_mesh_vertices(mesh, meshpackage).shape[0] for mesh in M]

    Adj, Trigs = get_adj_trigs(A, F, shapedata.reference_mesh,
                               meshpackage=meshpackage)

    print("Generating Spirals ..")
    spirals_np, spiral_sizes, _ = generate_spirals(
        step_sizes, M, Adj, Trigs,
        reference_points=reference_points,
        dilation=dilation,
        random=False,
        meshpackage=meshpackage,
        counter_clockwise=True)

    tspirals = [torch.from_numpy(s).long().to(device) for s in spirals_np]
    tD = [torch.from_numpy(_pad_with_dummy_vertex(d)).float().to(device)
          for d in D]
    tU = [torch.from_numpy(_pad_with_dummy_vertex(u)).float().to(device)
          for u in U]
    return shapedata, tspirals, spiral_sizes, vnum, tD, tU