def generate_transform_matrices(name, refer_vertices, refer_triangles, factors):
    """Generate a hierarchy of progressively decimated meshes.

    Starting from the reference mesh, each level is decimated by the
    corresponding entry of ``factors``.  Every mesh in the hierarchy is also
    written to ``data/reference/<name>/reference<i>.obj`` for inspection.

    Args:
        name: Sub-directory name used when writing the reference .obj files.
        refer_vertices: Vertex array of the reference mesh.
        refer_triangles: Triangle index array of the reference mesh.
        factors: Decimation factors, one per pyramid level.

    Returns:
        adjacencies: Vertex-adjacency matrix for each mesh
            (``len(factors) + 1`` entries, finest first).
        downsamp_trans: Downsampling transform between consecutive levels.
        upsamp_trans: Upsampling transform between consecutive levels.
    """
    # qslim expects the fraction of vertices to keep, so invert the factors.
    factors = [1.0 / x for x in factors]
    vertices = [refer_vertices]
    triangles = [refer_triangles]
    adjacencies = [utils.get_vert_connectivity(refer_vertices, refer_triangles)]
    downsamp_trans = []
    upsamp_trans = []
    for factor in factors:
        # Decimate the finest mesh so far; ds_transform maps fine -> coarse.
        ds_triangle, ds_transform = qslim_decimator_transformer(
            vertices[-1], triangles[-1], factor=factor)
        downsamp_trans.append(ds_transform)
        ds_vertice = ds_transform.dot(vertices[-1])
        vertices.append(ds_vertice)
        triangles.append(ds_triangle)
        adjacencies.append(utils.get_vert_connectivity(ds_vertice, ds_triangle))
        # Upsampling transform maps the new coarse level back to the previous
        # (finer) level.
        upsamp_trans.append(
            setup_deformation_transfer(vertices[-1], triangles[-1],
                                       vertices[-2]))
    for i, (vertice, triangle) in enumerate(zip(vertices, triangles)):
        write_obj(
            os.path.join('data', 'reference', name,
                         'reference{}.obj'.format(i)), vertice, triangle)
    return adjacencies, downsamp_trans, upsamp_trans
def process(self):
    """Build train/val/test splits from ``self.data_file`` and save them.

    Splits are selected by ``self.split``:
      * 'sliced'     : deterministic modulo-100 slicing of the file index.
      * 'expression' : directory at path index -2 equal to ``self.split_term``
                       goes to test.
      * 'identity'   : directory at path index -3 equal to ``self.split_term``
                       goes to test.

    Saves the collated splits to ``self.processed_paths[0..2]`` and the
    train-set normalisation statistics to ``self.processed_paths[3]``.
    """
    train_data, val_data, test_data = [], [], []
    train_vertices = []
    for idx, data_file in tqdm(enumerate(self.data_file)):
        mesh = Mesh(filename=data_file)
        mesh_verts = torch.Tensor(mesh.v)
        adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
        edge_index = torch.Tensor(np.vstack(
            (adjacency.row, adjacency.col)))
        data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
        if self.split == 'sliced':
            if idx % 100 <= 10:
                test_data.append(data)
            elif idx % 100 <= 20:
                val_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        elif self.split == 'expression':
            if data_file.split('/')[-2] == self.split_term:
                test_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        elif self.split == 'identity':
            if data_file.split('/')[-3] == self.split_term:
                test_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        else:
            raise Exception(
                'sliced, expression and identity are the only supported split terms'
            )
    if self.split != 'sliced':
        # expression/identity splits carve the validation set out of the
        # held-out test files.
        val_data = test_data[-self.nVal:]
        test_data = test_data[:-self.nVal]
    mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
    std_train = torch.Tensor(np.std(train_vertices, axis=0))
    norm_dict = {'mean': mean_train, 'std': std_train}
    if self.pre_transform is not None:
        if hasattr(self.pre_transform, 'mean') and hasattr(
                self.pre_transform, 'std'):
            # Bug fix: the attribute was previously misspelled
            # 'self.pre_tranform', which raised AttributeError whenever
            # mean/std needed seeding.
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
        train_data = [self.pre_transform(td) for td in train_data]
        val_data = [self.pre_transform(td) for td in val_data]
        test_data = [self.pre_transform(td) for td in test_data]
    torch.save(self.collate(train_data), self.processed_paths[0])
    torch.save(self.collate(val_data), self.processed_paths[1])
    torch.save(self.collate(test_data), self.processed_paths[2])
    torch.save(norm_dict, self.processed_paths[3])
def process(self):
    """Load all meshes, attach one-hot class labels, and save the dataset.

    For ``self.dtype == 'train'`` the per-vertex mean/std are computed and
    stored alongside the data; for ``'test'`` the previously saved
    statistics are loaded so the same normalisation is applied.
    """
    dataset = []
    vertices = []
    for dx in self.filepaths:
        for fp in self.filepaths[dx]:
            mesh = Mesh(filename=fp)
            mesh_verts = torch.Tensor(mesh.v)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            edge_index = torch.Tensor(
                np.vstack((adjacency.row, adjacency.col)))
            # One-hot class label derived from the category key.
            i = self.categories.index(dx)
            label = np.zeros(len(self.categories))
            label[i] = 1
            data = Data(x=mesh_verts,
                        y=torch.Tensor(label),
                        edge_index=edge_index)
            dataset.append(data)
            vertices.append(mesh.v)
    if self.dtype == 'train':
        mean_train = torch.Tensor(np.mean(vertices, axis=0))
        std_train = torch.Tensor(np.std(vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
        # Explicit None guard for consistency with the sibling process()
        # implementations (hasattr(None, ...) happens to be False anyway).
        if self.pre_transform is not None and hasattr(
                self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
            self.data = [self.pre_transform(td) for td in dataset]
        # NOTE(review): the collated save uses `dataset`; its Data objects
        # are only normalised if pre_transform mutates them in place —
        # verify against the transform implementation.
        torch.save(self.collate(dataset), self.processed_paths[0])
        torch.save(norm_dict, self.processed_paths[2])
    elif self.dtype == 'test':
        norm_path = self.processed_paths[2]
        norm_dict = torch.load(norm_path)
        mean_train, std_train = norm_dict['mean'], norm_dict['std']
        if self.pre_transform is not None and hasattr(
                self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
            self.data = [self.pre_transform(td) for td in dataset]
        torch.save(self.collate(dataset), self.processed_paths[1])
def process(self):
    """Compute train-set normalisation stats, then save each sample to disk.

    Writes ``norm_dict`` (per-vertex mean/std over the training meshes) to
    ``self.processed_paths[0]`` and one ``Data`` object per mesh to
    ``<processed_dir>/<split_term>/<subset>/data_<idx>.pt``.
    """
    print('Computing mean and std ...')
    train_vertices = []
    for idx, data_file in tqdm(enumerate(self.datafile_train)):
        mesh = Mesh(filename=data_file)
        train_vertices.append(mesh.v)
    mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
    std_train = torch.Tensor(np.std(train_vertices, axis=0))
    norm_dict = {'mean': mean_train, 'std': std_train}
    dest_path = osp.join(self.processed_dir, self.split_term)
    if not os.path.exists(dest_path):
        os.makedirs(dest_path)
    torch.save(norm_dict, self.processed_paths[0])
    if self.pre_transform is not None:
        if hasattr(self.pre_transform, 'mean') and hasattr(
                self.pre_transform, 'std'):
            # Bug fix: attribute was previously misspelled 'pre_tranform',
            # which raised AttributeError whenever this branch ran.
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
    subsets = ['train', 'test', 'val']
    for subset in subsets:
        print('processing {} ...'.format(subset))
        dest_path = osp.join(self.processed_dir, self.split_term,
                             '{}'.format(subset))
        if not os.path.exists(dest_path):
            os.makedirs(dest_path)
        vertices = []
        # getattr replaces the previous eval('self.datafile_' + subset):
        # same attribute lookup without evaluating a code string.
        for idx, data_file in tqdm(
                enumerate(getattr(self, 'datafile_' + subset))):
            mesh = Mesh(filename=data_file)
            mesh_verts = torch.Tensor(mesh.v)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            # LongTensor: edge indices need an integer dtype downstream.
            edge_index = torch.LongTensor(
                np.vstack((adjacency.row, adjacency.col)))
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
            vertices.append(mesh.v)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            torch.save(
                data,
                osp.join(self.processed_dir, self.split_term,
                         '{}/data_{}.pt'.format(subset, idx)))
def process(self):
    """Build AD/CN classification splits and save them.

    ``self.data_file`` is a nested dict: diagnosis ('ad'/'cn') -> subset
    ('train'/'test') -> list of mesh file paths.  Labels are one-hot:
    [1, 0] for 'ad', [0, 1] for 'cn'.
    """
    train_data, val_data, test_data = [], [], []
    train_vertices = []
    for dx in self.data_file:
        for tt in self.data_file[dx]:
            for f in tqdm(self.data_file[dx][tt]):
                mesh = Mesh(filename=f)
                mesh_verts = torch.Tensor(mesh.v)
                adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                edge_index = torch.Tensor(
                    np.vstack((adjacency.row, adjacency.col)))
                if dx == 'ad':
                    data = Data(x=mesh_verts,
                                y=torch.Tensor([1, 0]),
                                edge_index=edge_index)
                elif dx == 'cn':
                    data = Data(x=mesh_verts,
                                y=torch.Tensor([0, 1]),
                                edge_index=edge_index)
                if tt == 'test':
                    test_data.append(data)
                elif tt == 'train':
                    train_data.append(data)
                    train_vertices.append(mesh.v)
    mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
    std_train = torch.Tensor(np.std(train_vertices, axis=0))
    norm_dict = {'mean': mean_train, 'std': std_train}
    if self.pre_transform is not None:
        if hasattr(self.pre_transform, 'mean') and hasattr(
                self.pre_transform, 'std'):
            # Bug fix: previously misspelled 'self.pre_tranform', which
            # raised AttributeError whenever mean/std needed seeding.
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
        train_data = [self.pre_transform(td) for td in train_data]
        val_data = [self.pre_transform(td) for td in val_data]
        test_data = [self.pre_transform(td) for td in test_data]
    torch.save(self.collate(train_data), self.processed_paths[0])
    # No validation split here: test goes to slot 1, norm stats to slot 2.
    torch.save(self.collate(test_data), self.processed_paths[1])
    torch.save(norm_dict, self.processed_paths[2])
def process(self):
    """Build train/val/test splits for the longitudinal dataset variants.

    ``self.data_file`` maps a dataset key to its file entries:
      * 'lgtd'   : [baseline_path, month24_path] — target is the later mesh.
      * 'lgtddx' : [baseline_path, month24_path, label] — adds a label.
      * 'lgtdvc' : [baseline_path, followup_path, label, period] — adds the
                   follow-up period as well.
      * 'ad' / 'cn' / 'all' : a single mesh path; target equals the input,
        and 'ad'/'cn' carry a one-hot label.

    Samples are split by index: idx%100 < 10 -> test, < 20 -> val, else
    train.  Collated splits go to ``processed_paths[0..2]``, the train-set
    normalisation stats to ``processed_paths[3]``.
    """
    train_data, val_data, test_data = [], [], []
    train_vertices = []
    for key in self.data_file:
        for idx, data_file in tqdm(enumerate(self.data_file[key])):
            if key == 'lgtd':
                mesh = Mesh(filename=data_file[0])
                mesh_verts = torch.Tensor(mesh.v)
                adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                edge_index = torch.Tensor(
                    np.vstack((adjacency.row, adjacency.col)))
                mesh_m24 = Mesh(filename=data_file[1])
                data = Data(x=mesh_verts,
                            y=torch.Tensor(mesh_m24.v),
                            edge_index=edge_index)
            elif key == 'lgtddx':
                mesh = Mesh(filename=data_file[0])
                mesh_verts = torch.Tensor(mesh.v)
                adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                edge_index = torch.Tensor(
                    np.vstack((adjacency.row, adjacency.col)))
                mesh_m24 = Mesh(filename=data_file[1])
                data = Data(x=mesh_verts,
                            y=torch.Tensor(mesh_m24.v),
                            edge_index=edge_index,
                            label=torch.Tensor(data_file[2]))
            elif key == 'lgtdvc':
                mesh = Mesh(filename=data_file[0])
                mesh_verts = torch.Tensor(mesh.v)
                adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                edge_index = torch.Tensor(
                    np.vstack((adjacency.row, adjacency.col)))
                mesh_fu = Mesh(filename=data_file[1])
                data = Data(x=mesh_verts,
                            y=torch.Tensor(mesh_fu.v),
                            edge_index=edge_index,
                            label=torch.Tensor(data_file[2]),
                            period=torch.Tensor(data_file[3]))
            else:
                mesh = Mesh(filename=data_file)
                mesh_verts = torch.Tensor(mesh.v)
                adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                edge_index = torch.Tensor(
                    np.vstack((adjacency.row, adjacency.col)))
                if key == 'ad':
                    data = Data(x=mesh_verts,
                                y=mesh_verts,
                                label=torch.Tensor([1, 0]),
                                edge_index=edge_index)
                elif key == 'cn':
                    data = Data(x=mesh_verts,
                                y=mesh_verts,
                                label=torch.Tensor([0, 1]),
                                edge_index=edge_index)
                elif key == 'all':
                    data = Data(x=mesh_verts,
                                y=mesh_verts,
                                edge_index=edge_index)
            if idx % 100 < 10:
                test_data.append(data)
            elif idx % 100 < 20:
                val_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
    print(len(train_data), len(val_data), len(test_data))
    mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
    std_train = torch.Tensor(np.std(train_vertices, axis=0))
    norm_dict = {'mean': mean_train, 'std': std_train}
    if self.pre_transform is not None:
        if hasattr(self.pre_transform, 'mean') and hasattr(
                self.pre_transform, 'std'):
            # Bug fix: previously misspelled 'self.pre_tranform', which
            # raised AttributeError whenever mean/std needed seeding.
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
        train_data = [self.pre_transform(td) for td in train_data]
        val_data = [self.pre_transform(td) for td in val_data]
        test_data = [self.pre_transform(td) for td in test_data]
    torch.save(self.collate(train_data), self.processed_paths[0])
    torch.save(self.collate(val_data), self.processed_paths[1])
    torch.save(self.collate(test_data), self.processed_paths[2])
    torch.save(norm_dict, self.processed_paths[3])
def process(self):
    """Build train/test splits, normalisation stats, and save them.

    Same split logic as the sibling dataset variants ('sliced',
    'expression', 'identity'); the 'sliced' split here produces no
    validation set.
    """
    train_data, val_data, test_data = [], [], []
    train_vertices = []
    for idx, data_file in tqdm(enumerate(self.data_file)):
        mesh = Mesh(filename=data_file)
        mesh_verts = torch.Tensor(mesh.v)
        adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
        edge_index = torch.Tensor(np.vstack(
            (adjacency.row, adjacency.col))).long()
        data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
        if self.split == 'sliced':
            if idx % 100 <= 10:
                test_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        elif self.split == 'expression':
            if data_file.split('/')[-2] == self.split_term:
                test_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        elif self.split == 'identity':
            if data_file.split('/')[-3] == self.split_term:
                test_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        else:
            raise Exception(
                'sliced, expression and identity are the only supported split terms'
            )
    if self.split != 'sliced':
        val_data = test_data[-self.nVal:]
        test_data = test_data[:-self.nVal]
    mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
    # Bug fix: a leftover debug 'print(mean_train.shape); return' aborted
    # the method here, so the splits and stats were never saved.
    std_train = torch.Tensor(np.std(train_vertices, axis=0))
    norm_dict = {'mean': mean_train, 'std': std_train}
    if self.pre_transform is not None:
        if hasattr(self.pre_transform, 'mean') and hasattr(
                self.pre_transform, 'std'):
            # Bug fix: attribute was previously misspelled 'pre_tranform'.
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
        train_data = [self.pre_transform(td) for td in train_data]
        val_data = [self.pre_transform(td) for td in val_data]
        test_data = [self.pre_transform(td) for td in test_data]
    torch.save(self.collate(train_data), self.processed_paths[0])
    # NOTE(review): test_data is saved to both slots 1 and 2 (there is no
    # val split under 'sliced') — confirm the duplication is intentional.
    torch.save(self.collate(test_data), self.processed_paths[1])
    torch.save(self.collate(test_data), self.processed_paths[2])
    torch.save(norm_dict, self.processed_paths[3])
def main(args):
    """Load config and a trained CoMA model, then export latent codes.

    For each subject folder listed in ``files``, encodes every mesh with the
    trained model, saves the latent vectors to ``./processed/0820/<file>.npy``
    and dumps a reconstructed/original .obj pair every 50 meshes for visual
    inspection.
    """
    if not os.path.exists(args.conf):
        # Bug fix: previously we only printed and then fell through to
        # read_config, which crashed on the missing file. Abort early.
        print('Config not found: ' + args.conf)
        return
    config = read_config(args.conf)
    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)
    # Command-line overrides take precedence over the config file.
    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])
    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]
    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']
    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('No optimizer provided')
    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Move restored optimizer state onto the target device; torch.load
        # restores these tensors on whatever device they were saved from.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)
    print('making...')
    norm = torch.load('../processed_data/processed/sliced_norm.pt')
    normalize_transform.mean = norm['mean']
    normalize_transform.std = norm['std']
    files = [
        '0514', '0503', '0507', '0509', '0512', '0501', '0901', '1001',
        '4902', '4913', '4919', '9302', '9305', '12411'
    ]
    coma.eval()
    meshviewer = MeshViewers(shape=(1, 2))
    for file in files:
        mesh_dir = os.listdir('../processed_data/' + file + '/mesh/')
        latent = []
        print(len(mesh_dir))
        for i in tqdm(range(len(mesh_dir))):
            data_file = '../processed_data/' + file + '/mesh/' + str(
                i) + '.obj'
            mesh = Mesh(filename=data_file)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            edge_index = torch.Tensor(np.vstack(
                (adjacency.row, adjacency.col))).long()
            # Normalise with the precomputed train-set statistics.
            mesh_verts = (torch.Tensor(mesh.v) -
                          normalize_transform.mean) / normalize_transform.std
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
            data = data.to(device)
            with torch.no_grad():
                out, feature = coma(data)
            latent.append(feature.cpu().detach().numpy())
            if i % 50 == 0:
                # De-normalise both the reconstruction and the input, then
                # dump them as .obj for visual comparison.
                expected_out = data.x
                out = out.cpu().detach(
                ) * normalize_transform.std + normalize_transform.mean
                expected_out = expected_out.cpu().detach(
                ) * normalize_transform.std + normalize_transform.mean
                out = out.numpy()
                save_obj(out, template_mesh.f + 1,
                         './vis/reconstruct_' + str(i) + '.obj')
                save_obj(expected_out, template_mesh.f + 1,
                         './vis/ori_' + str(i) + '.obj')
        np.save('./processed/0820/' + file, latent)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
def process(self, transform=None):
    """Build train/val/test splits with per-mesh normalised-scale vertices.

    Each mesh is centred and scaled into [-1, 1] (the recipe from pyG's
    transforms.NormalizeScale) before being wrapped in a Data object.
    Splitting follows ``self.split`` ('sliced' / 'expression' / 'identity').
    Collated splits are saved to ``self.processed_paths[0..2]``, the
    train-set normalisation stats to ``self.processed_paths[3]``.
    """
    train_data, val_data, test_data = [], [], []
    train_vertices = []
    for idx, data_file in tqdm(enumerate(self.data_file)):
        mesh = Mesh(filename=data_file)
        mesh_verts = torch.Tensor(mesh.v)
        adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
        edge_index = torch.Tensor(np.vstack(
            (adjacency.row, adjacency.col))).type(torch.LongTensor)
        # Centre and scale into [-1, 1] (taken from pyG
        # transform.NormalizedScale()).
        mesh_verts = mesh_verts - mesh_verts.mean(dim=-2, keepdim=True)
        scale = (1 / mesh_verts.abs().max()) * 0.999999
        mesh_verts = mesh_verts * scale
        data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
        if self.split == 'sliced':
            if idx % 100 <= 10:
                test_data.append(data)
            elif idx % 100 <= 20:
                val_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        elif self.split == 'expression':
            if data_file.split('/')[-2] == self.split_term:
                test_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        elif self.split == 'identity':
            if data_file.split('/')[-3] == self.split_term:
                test_data.append(data)
            else:
                train_data.append(data)
                train_vertices.append(mesh.v)
        else:
            raise Exception(
                'sliced, expression and identity are the only supported split terms'
            )
    if self.split != 'sliced':
        val_data = test_data[-self.nVal:]
        test_data = test_data[:-self.nVal]
    # Stats are computed over the raw (un-normalised) training vertices.
    mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
    std_train = torch.Tensor(np.std(train_vertices, axis=0))
    norm_dict = {'mean': mean_train, 'std': std_train}
    if self.pre_transform is not None:
        if hasattr(self.pre_transform, 'mean') and hasattr(
                self.pre_transform, 'std'):
            # Bug fix: previously misspelled 'self.pre_tranform', which
            # raised AttributeError whenever mean/std needed seeding.
            if self.pre_transform.mean is None:
                self.pre_transform.mean = mean_train
            if self.pre_transform.std is None:
                self.pre_transform.std = std_train
        train_data = [self.pre_transform(td) for td in train_data]
        val_data = [self.pre_transform(td) for td in val_data]
        test_data = [self.pre_transform(td) for td in test_data]
    print('train_data[0].x : {}'.format(train_data[0].x))
    print('\n\n')
    # Bug fix: the debug label said train_data[0] but printed train_data[1].
    print('train_data[1].x : {}'.format(train_data[1].x))
    torch.save(self.collate(train_data), self.processed_paths[0])
    torch.save(self.collate(val_data), self.processed_paths[1])
    torch.save(self.collate(test_data), self.processed_paths[2])
    torch.save(norm_dict, self.processed_paths[3])