def train(self, epochs):
    self.running_loss = 0.0
    # Build the MANO layer once, rather than once per batch.
    mano_layer = ManoLayer(mano_root='mano/models',
                           use_pca=False,
                           ncomps=48,
                           flat_hand_mean=False)
    mano_layer.to(self.device)
    for epoch in range(epochs):
        for i, sample in enumerate(tqdm(self.dataloader), 0):
            poses = sample['poses']
            shapes = sample['shapes']

            # Forward pass through MANO layer
            _, hand_joints = mano_layer(poses, shapes)

            uv_root = sample['uv_root']
            scale = sample['scale']
            hand_joints = hand_joints.reshape([self.batch_size, -1])

            # Network input: flattened joints + root uv + scale, duplicated once.
            x = torch.cat((hand_joints, uv_root, scale), 1)
            x = torch.cat((x, x), 1)
            y = sample['xyz'].reshape([self.batch_size, -1])

            x = x.to(self.device)
            y = y.to(self.device)

            self.optimizer.zero_grad()
            y_ = self.model(x)
            loss = self.criterion(y_, y)
            loss.backward()
            self.optimizer.step()

            self.running_loss += loss.item()
            self.g_step += 1
            if self.g_step % self.save_rate == self.save_rate - 1:
                # Average the accumulated loss once before logging and saving.
                self.running_loss /= self.save_rate
                self.save_state()
                self.writer.add_scalar('training loss', self.running_loss,
                                       self.g_step)
                print(self.running_loss, self.g_step)
                self.running_loss = 0.0
class ManoDatasetC(Dataset):
    def __init__(self, base_path, transform, train_indices):
        self.transform = transform
        mano_path = os.path.join(base_path, '%s_mano.json' % 'training')
        mano_list = json_load(mano_path)
        mano_array = np.array(mano_list).squeeze(1)

        # Keep the 48 pose parameters plus the first 3 shape parameters.
        mano_poses = mano_array[..., :51]
        mano_poses = mano_poses[train_indices]

        # Fit a kernel density estimate over the training MANO parameters,
        # so __getitem__ can draw new, plausible hand configurations.
        self.kde = KernelDensity(bandwidth=0.15, kernel='gaussian')
        self.kde.fit(mano_poses)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mano_layer = ManoLayer(mano_root='mano/models',
                                    use_pca=False,
                                    ncomps=45,
                                    flat_hand_mean=False)
        self.mano_layer.to(self.device)

    def __len__(self):
        # Fixed length (size of the FreiHAND training split).
        return 32560

    def __getitem__(self, idx):
        # Draw one synthetic parameter vector from the fitted KDE.
        sample = self.kde.sample()
        pose = sample[..., :48]
        shape_start = sample[..., 48:]
        shape = np.ones([1, 10])
        shape[..., :3] = shape_start

        x = {'p': pose, 's': shape}
        x = self.transform(x)

        hand_verts, hand_joints = self.mano_layer(x['p'], x['s'])
        batch_size = hand_joints.shape[0]
        hand_joints = hand_joints.reshape([batch_size, 63])

        sample = {
            'hand_joints': torch.squeeze(hand_joints),
            'hand_verts': torch.squeeze(hand_verts),
            'poses': torch.squeeze(x['p']),
            'shapes': torch.squeeze(x['s'])
        }
        return sample
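# Usage sketch (not part of the original file): wrapping ManoDatasetC in a
# DataLoader. The base path and index range are illustrative assumptions;
# ToTensor is assumed to be this repo's dict-to-tensor transform, and the
# dataset expects a FreiHAND-style 'training_mano.json' under base_path plus
# the MANO model files under 'mano/models'.
dataset_c = ManoDatasetC(base_path='data/FreiHAND',
                         transform=ToTensor(),
                         train_indices=np.arange(30000))
loader_c = DataLoader(dataset_c, batch_size=16, shuffle=True)
batch = next(iter(loader_c))
print(batch['hand_joints'].shape)  # expected: torch.Size([16, 63])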
def xyz_from_mano(poses, shapes):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Select number of principal components for pose space
    ncomps = 45

    # Initialize MANO layer
    mano_layer = ManoLayer(mano_root='mano/models',
                           use_pca=False,
                           ncomps=ncomps,
                           flat_hand_mean=False)
    mano_layer.to(device)

    poses = torch.from_numpy(poses).float().to(device)
    shapes = torch.from_numpy(shapes).float().to(device)

    # Forward pass through MANO layer
    hand_verts, hand_joints = mano_layer(poses, shapes)
    return hand_verts, hand_joints
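# Usage sketch (not part of the original file): calling xyz_from_mano with
# batched numpy arrays. The zero pose/shape below are placeholders; poses are
# (B, 48) axis-angle vectors (3 global rotation + 45 joint angles) and shapes
# are (B, 10) MANO betas. With the manopth ManoLayer this yields (B, 778, 3)
# vertices and (B, 21, 3) joints.
example_poses = np.zeros((1, 48), dtype=np.float32)
example_shapes = np.zeros((1, 10), dtype=np.float32)
example_verts, example_joints = xyz_from_mano(example_poses, example_shapes)
print(example_verts.shape, example_joints.shape)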
class TrainerP0(object):
    def __init__(self, batch_size, dataloader, model, build_id):
        self.batch_size = batch_size
        self.dataloader = dataloader
        self.model = model
        self.save_path = f'results/{build_id}.pt'
        self.writer = SummaryWriter(f'results/{build_id}')

        input_example = next(iter(dataloader))['uv']
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.writer.add_graph(self.model, input_example.to(self.device))
        self.writer.close()

        self.criterion = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=0.0001,
                                   momentum=0.9)
        self.load_state()

        ncomps = 45
        self.mano_layer = ManoLayer(mano_root='mano/models',
                                    use_pca=False,
                                    ncomps=ncomps,
                                    flat_hand_mean=False)
        self.mano_layer.to(self.device)
        self.mano_layer.eval()

    def train(self, epochs, save_rate):
        self.running_loss = 0.0
        for epoch in range(epochs):
            for i, sample in enumerate(self.dataloader, 0):
                uv = sample['uv']
                outputs_gt = sample['mano']
                uv = uv.to(self.device)
                outputs_gt = outputs_gt.to(self.device)

                self.optimizer.zero_grad()
                # Fixed output scaling constant kept from the original code.
                outputs = self.model(uv) * 3.12

                loss = self.criterion(outputs, outputs_gt)

                # Alternative loss on MANO vertices/joints (kept disabled):
                # losses = []
                # losses.append(self.criterion(outputs, outputs_gt))
                # poses_pred = outputs[:, :48]
                # shapes_pred = outputs[:, 48:]
                # hand_verts_p, hand_joints_p = self.mano_layer(poses_pred, shapes_pred)
                # poses_gt = outputs_gt[:, :48]
                # shapes_gt = outputs_gt[:, 48:]
                # hand_verts_gt, hand_joints_gt = self.mano_layer(poses_gt, shapes_gt)
                # losses.append(self.criterion(hand_verts_p, hand_verts_gt))
                # losses.append(self.criterion(hand_joints_p, hand_joints_gt))
                # loss = sum(losses)

                loss.backward()
                self.optimizer.step()

                self.running_loss += loss.item()
                self.g_step += 1
                if self.g_step % save_rate == save_rate - 1:
                    # Average the accumulated loss once, then checkpoint and log.
                    self.running_loss /= save_rate
                    self.save_state()
                    self.writer.add_scalar('training loss', self.running_loss,
                                           self.g_step)
                    print(self.running_loss, self.g_step)
                    self.running_loss = 0.0

    def save_state(self):
        torch.save(
            {
                'g_step': self.g_step,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'running_loss': self.running_loss,
            }, self.save_path)
        print(
            f'Model saved at step {self.g_step} with running loss {self.running_loss}.'
        )

    def load_state(self):
        if os.path.exists(self.save_path):
            checkpoint = torch.load(self.save_path)
            self.g_step = checkpoint['g_step'] + 1
            self.running_loss = checkpoint['running_loss']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(
                f'Model loaded. g_step: {self.g_step}; running_loss: {self.running_loss}'
            )
        else:
            print(
                f'File "{self.save_path}" does not exist. Initializing parameters from scratch.'
            )
            self.g_step = 0
            self.running_loss = 0.0
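# Usage sketch (not part of the original file): minimal wiring for TrainerP0
# with a toy in-memory dataset and a small MLP, assuming the usual imports
# (torch, nn, Dataset, DataLoader) from this repo. The 42-dim 'uv' input and
# 51-dim 'mano' target are illustrative assumptions; the trainer itself only
# requires batches containing 'uv' and 'mano' tensors and MANO model files
# under 'mano/models' for its ManoLayer.
class ToyUVManoDataset(Dataset):
    def __len__(self):
        return 256

    def __getitem__(self, idx):
        # Random placeholder data, just to exercise the training loop.
        return {'uv': torch.randn(42), 'mano': torch.randn(51)}


toy_loader = DataLoader(ToyUVManoDataset(), batch_size=64, shuffle=True)
toy_model = nn.Sequential(nn.Linear(42, 256), nn.ReLU(), nn.Linear(256, 51))
toy_trainer = TrainerP0(batch_size=64,
                        dataloader=toy_loader,
                        model=toy_model,
                        build_id='p0_toy')
toy_trainer.train(epochs=1, save_rate=4)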
CONFIG_M0['data_version'], transform=ToTensor())
dataloader = DataLoader(dataset, shuffle=True)
sample = next(iter(dataloader))
random_pose = sample['poses']
random_shape = sample['shapes']
print(random_pose.shape)

# Initialize MANO layer
mano_layer = ManoLayer(mano_root='mano/models',
                       use_pca=False,
                       ncomps=48,
                       flat_hand_mean=False)
mano_layer.to(device)

# Forward pass through MANO layer
hand_verts, hand_joints = mano_layer(random_pose, random_shape)
# demo.display_hand({
#     'verts': hand_verts,
#     'joints': hand_joints
# },
#                   mano_faces=mano_layer.th_faces)

uv_root = sample['uv_root']
scale = sample['scale']
hand_joints = hand_joints.reshape([1, -1])
x = torch.cat((hand_joints, uv_root, scale), 1)
x = torch.cat((x, x), 1)

# NETWORK WAY
class pcl2mano():
    def __init__(self,
                 device,
                 mano_root='mano/models',
                 ncomps=45,
                 seg_file='./face2label_sealed.npy',
                 visualize_interm_result=False):
        self.ncomps = ncomps
        self.device = device
        self.mano_root = mano_root
        self.mano_layer = ManoLayer(mano_root=self.mano_root,
                                    use_pca=True,
                                    ncomps=self.ncomps,
                                    flat_hand_mean=False)
        self.mano_layer = self.mano_layer.to(self.device)

        segmentation = np.load(seg_file)
        self.faces = self.mano_layer.th_faces.detach().cpu()

        # Assign a MANO vertex label according to the face label:
        self.vertex_label = np.zeros((self.faces.max() + 1), dtype=np.uint8)
        for i in range(0, self.faces.shape[0]):
            self.vertex_label[self.faces[i, 0]] = segmentation[i, 0]
            self.vertex_label[self.faces[i, 1]] = segmentation[i, 0]
            self.vertex_label[self.faces[i, 2]] = segmentation[i, 0]
        self.label_vertex = [(np.where(self.vertex_label == 0))[0],
                             (np.where(self.vertex_label == 1))[0],
                             (np.where(self.vertex_label == 2))[0],
                             (np.where(self.vertex_label == 3))[0],
                             (np.where(self.vertex_label == 4))[0],
                             (np.where(self.vertex_label == 5))[0]]

        self.visualize_interm_result = visualize_interm_result
        self.target_points = []
        self.target_points_tree = []
        self.no_label = []

    def find_nearest_neighbour_index_from_hand(self, target_points_tree, hands,
                                               no_label):
        if len(hands.shape) == 3:  # batch of hands
            closest_indices = []
            for i in range(0, hands.shape[0]):
                hand = hands[i, :, :].detach().cpu().numpy()
                _, closest_palm_index = target_points_tree[0].query(
                    hand[self.label_vertex[0], :], 1)
                _, closest_thum_index = target_points_tree[1].query(
                    hand[self.label_vertex[1], :], 1)
                _, closest_inde_index = target_points_tree[2].query(
                    hand[self.label_vertex[2], :], 1)
                _, closest_midd_index = target_points_tree[3].query(
                    hand[self.label_vertex[3], :], 1)
                _, closest_ring_index = target_points_tree[4].query(
                    hand[self.label_vertex[4], :], 1)
                _, closest_pink_index = target_points_tree[5].query(
                    hand[self.label_vertex[5], :], 1)
                closest_indices.append([
                    closest_palm_index, closest_thum_index, closest_inde_index,
                    closest_midd_index, closest_ring_index, closest_pink_index
                ])
            return closest_indices
        else:  # Not used
            print("Fatal error in find_nearest_neighbour_index_from_hand!")
            exit()

    def index2points_from_hand(self, indices, closest_points):
        for batch in range(0, len(indices)):  # iterate over batch dimension
            for label in range(0, 6):  # iterate over hand parts
                closest_points[
                    batch, self.label_vertex[label][:], :] = self.target_points[
                        label][indices[batch][label][:, 0], :]
        return closest_points

    def find_nearest_neighbour_to_hand(self, target_points, n, hands, no_label):
        closest_batch_hand_indices = np.zeros((0, n), dtype=np.int32)
        target_points_batch_rearrange = np.zeros((0, n, 3), dtype=np.float32)
        for i in range(0, hands.shape[0]):  # iterate over batch dimension
            hand = hands[i, :, :].detach().cpu().numpy()
            closest_hand_indices = np.zeros((0), dtype=np.int32)
            target_points_rearrange = np.zeros((0, 3), dtype=np.float32)
            for hand_part in range(0, 6):
                # If there is no predicted label of that part, skip it.
                if no_label[hand_part]:
                    continue
                hand_part_vertex = hand[self.label_vertex[hand_part], :]
                hand_part_tree = KDTree(hand_part_vertex)
                _, closest_part_index = hand_part_tree.query(
                    target_points[hand_part], 1)
                closest_hand_indices = np.concatenate(
                    (closest_hand_indices,
                     self.label_vertex[hand_part][closest_part_index[:, 0]]), 0)
                target_points_rearrange = np.concatenate(
                    (target_points_rearrange, target_points[hand_part]))
            target_points_rearrange = np.expand_dims(target_points_rearrange,
                                                     axis=0)
            target_points_batch_rearrange = np.concatenate(
                (target_points_batch_rearrange, target_points_rearrange), 0)
            closest_hand_indices = np.expand_dims(closest_hand_indices, axis=0)
            closest_batch_hand_indices = np.concatenate(
                (closest_batch_hand_indices, closest_hand_indices), 0)
        return closest_batch_hand_indices, target_points_batch_rearrange

    def index2points_to_hand(self, indices, hands):
        hands_batch_rearrange = torch.zeros(0, indices.shape[1],
                                            3).float().to(self.device)
        for batch in range(0, indices.shape[0]):
            hand_rearrange = hands[batch, indices[batch, :], :]
            hand_rearrange = hand_rearrange.unsqueeze(0)
            hands_batch_rearrange = torch.cat(
                (hands_batch_rearrange, hand_rearrange), dim=0)
        return hands_batch_rearrange

    def mask_no_label_from_hand(self, hand_verts, no_label):
        for i in range(len(no_label)):
            if no_label[i]:
                hand_verts[:, self.label_vertex[i], :] = 0.
        return hand_verts

    def mask_no_label_to_hand(self, hand_verts, no_label):
        for i in range(len(no_label)):
            if no_label[i]:
                hand_verts[:, self.label_vertex[i], :] = 0.
        return hand_verts

    def fit_mano_2_pcl(
            self,
            samples,  # samples is a (N_v x 3) array, unit is millimeters!
            labels,  # labels is a (N_v x 1) array
            seeds=8,
            coarse_iter=50,
            fine_iter=50,
            stop_loss=5.0,
            verbose=0):
        # Classify samples according to their labels.
        palm = samples[(np.where(labels == 0))[0], :]
        thumb = samples[(np.where(labels == 1))[0], :]
        index = samples[(np.where(labels == 2))[0], :]
        middle = samples[(np.where(labels == 3))[0], :]
        ring = samples[(np.where(labels == 4))[0], :]
        pinky = samples[(np.where(labels == 5))[0], :]
        for lab in [palm, thumb, index, middle, ring, pinky]:
            self.no_label.append(lab.shape[0] == 0)

        # Add a temporary point (0, 0, 0) to any part with no samples.
        # The loss from these points will be masked out later when the loss is calculated.
        if palm.shape[0] == 0:
            palm = np.zeros([1, 3])
        if thumb.shape[0] == 0:
            thumb = np.zeros([1, 3])
        if index.shape[0] == 0:
            index = np.zeros([1, 3])
        if middle.shape[0] == 0:
            middle = np.zeros([1, 3])
        if ring.shape[0] == 0:
            ring = np.zeros([1, 3])
        if pinky.shape[0] == 0:
            pinky = np.zeros([1, 3])

        self.target_points_np = [palm, thumb, index, middle, ring, pinky]
        self.target_points = [torch.from_numpy(palm).float().to(self.device),
                              torch.from_numpy(thumb).float().to(self.device),
                              torch.from_numpy(index).float().to(self.device),
                              torch.from_numpy(middle).float().to(self.device),
                              torch.from_numpy(ring).float().to(self.device),
                              torch.from_numpy(pinky).float().to(self.device)]
        self.target_points_tree = [
            KDTree(palm),
            KDTree(thumb),
            KDTree(index),
            KDTree(middle),
            KDTree(ring),
            KDTree(pinky)
        ]

        # Model parameter initialization:
        shape = torch.zeros(seeds, 10).float().to(self.device)
        shape.requires_grad_()
        rot = torch.zeros(seeds, 3).float().to(self.device)
        rot.requires_grad_()
        pose = (0.1 * torch.randn(seeds, self.ncomps)).float().to(self.device)
        pose.requires_grad_()
        trans = torch.from_numpy(samples.mean(0) / 1000.0)  # trans should be in meters
        trans = trans.unsqueeze(0).repeat(seeds, 1).float().to(self.device)
        # trans = (0.1 * torch.randn(seeds, 3)).float().to(self.device)
        trans.requires_grad_()

        hand_verts, hand_joints = self.mano_layer(torch.cat((rot, pose), 1),
                                                  shape, trans)
        if self.visualize_interm_result:
            demo.display_mosh(
                torch.from_numpy(samples).float().unsqueeze(0).expand(seeds, -1, -1),
                np.zeros((0, 4), dtype=np.int32),
                {'verts': hand_verts.detach().cpu(),
                 'joints': hand_joints.detach().cpu()},
                mano_faces=self.mano_layer.th_faces.detach().cpu(),
                alpha=0.3)

        # Global optimization: fit only translation and global rotation first.
        criteria_loss = nn.MSELoss().to(self.device)
        previous_loss = 1e8
        optimizer = torch.optim.Adam([trans, rot], lr=1e-2)
        print('...Optimizing global transformation...')
        for i in range(0, coarse_iter):
            hand_verts, hand_joints = self.mano_layer(
                torch.cat((rot, pose), 1), shape, trans)
            # Find closest labelled points:
            closest_indices = self.find_nearest_neighbour_index_from_hand(
                self.target_points_tree, hand_verts, self.no_label)
            closest_points = self.index2points_from_hand(
                closest_indices, torch.zeros_like(hand_verts))
            for j in range(0, 20):
                hand_verts, hand_joints = self.mano_layer(
                    torch.cat((rot, pose), 1), shape, trans)
                hand_verts = self.mask_no_label_from_hand(
                    hand_verts, self.no_label)
                loss = criteria_loss(hand_verts, closest_points)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss = criteria_loss(hand_verts, closest_points)
            if verbose >= 1:
                print(i, loss.data)
            if previous_loss - loss.data < 1e-1:
                break
            previous_loss = loss.data.detach()

        if self.visualize_interm_result:
            demo.display_mosh(
                torch.from_numpy(samples).float().unsqueeze(0).expand(seeds, -1, -1),
                np.zeros((0, 4), dtype=np.int32),
                {'verts': hand_verts.detach().cpu(),
                 'joints': hand_joints.detach().cpu()},
                mano_faces=self.mano_layer.th_faces.detach().cpu(),
                alpha=0.3)

        # Local optimization: refine pose, shape and the global transformation.
        previous_loss = 1e8
        optimizer = torch.optim.Adam([trans, rot, pose, shape], lr=1e-2)
        print('...Optimizing hand pose, shape and global transformation...')
        for i in range(0, fine_iter):
            hand_verts, hand_joints = self.mano_layer(
                torch.cat((rot, pose), 1), shape, trans)
            # Find closest labelled points:
            closest_batch_hand_indices, target_points_batch_rearrange = \
                self.find_nearest_neighbour_to_hand(self.target_points_np,
                                                    samples.shape[0],
                                                    hand_verts, self.no_label)
            target_points_batch_rearrange = torch.from_numpy(
                target_points_batch_rearrange).float().to(self.device)
            for j in range(0, 20):
                hand_verts, hand_joints = self.mano_layer(
                    torch.cat((rot, pose), 1), shape, trans)
                hands_batch_rearrange = self.index2points_to_hand(
                    closest_batch_hand_indices, hand_verts)
                w_pose = 100.0
                loss = criteria_loss(
                    hands_batch_rearrange, target_points_batch_rearrange
                ) + w_pose * (pose * pose).mean()  # pose regularizer
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss = criteria_loss(hands_batch_rearrange,
                                 target_points_batch_rearrange)
            # if previous_loss - loss.data < 1e-1:
            #     break
            previous_loss = loss.data.detach()

            # Find the smallest loss among the seeds:
            per_seed_error = (
                (hands_batch_rearrange - target_points_batch_rearrange) *
                (hands_batch_rearrange -
                 target_points_batch_rearrange)).mean(2).mean(1)
            min_index = torch.argmin(per_seed_error).detach().cpu().numpy()
            min_error = per_seed_error[min_index]
            if verbose >= 1:
                print(i, min_error.data)

            if self.visualize_interm_result and i % 40 == 0:
                tmp_arange = np.expand_dims(np.arange(
                    target_points_batch_rearrange.shape[1]), axis=1)
                link = np.concatenate(
                    (tmp_arange,
                     tmp_arange + target_points_batch_rearrange.shape[1]), 1)
                link = link[0:200, :]
                visual_points = torch.cat(
                    (target_points_batch_rearrange[min_index, :, :],
                     hands_batch_rearrange[min_index, :, :]), dim=0)
                visual_points = visual_points.unsqueeze(0).expand(seeds, -1, -1)
                demo.display_mosh(visual_points.detach().cpu(), link,
                                  {'verts': hand_verts.detach().cpu(),
                                   'joints': hand_joints.detach().cpu()},
                                  mano_faces=self.mano_layer.th_faces.detach().cpu(),
                                  alpha=0.3)

            if min_error < stop_loss:
                break

        hand_verts, hand_joints = self.mano_layer(torch.cat((rot, pose), 1),
                                                  shape, trans)
        hand_shape = {'vertices': hand_verts.detach().cpu().numpy()[min_index, :, :],
                      'joints': hand_joints.detach().cpu().numpy()[min_index, :, :],
                      'faces': self.mano_layer.th_faces.detach().cpu()}
        mano_para = {'rot': rot.detach().cpu().numpy()[min_index, :],
                     'pose': pose.detach().cpu().numpy()[min_index, :],
                     'shape': shape.detach().cpu().numpy()[min_index, :],
                     'trans': trans.detach().cpu().numpy()[min_index, :]}
        return hand_shape, mano_para
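# Usage sketch (not part of the original file): fitting MANO to a labelled hand
# point cloud with pcl2mano, assuming numpy/torch are imported as above and the
# default 'face2label_sealed.npy' and 'mano/models' files exist on disk. The
# random cloud below is only a placeholder; real inputs are points in
# millimeters with per-point part labels in 0..5
# (palm, thumb, index, middle, ring, pinky).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fitter = pcl2mano(device)
toy_samples = 100.0 * np.random.rand(1024, 3).astype(np.float32)  # placeholder cloud (mm)
toy_labels = np.random.randint(0, 6, size=(1024, 1))              # placeholder labels
hand_shape, mano_para = fitter.fit_mano_2_pcl(toy_samples,
                                              toy_labels,
                                              seeds=4,
                                              coarse_iter=20,
                                              fine_iter=20)
print(hand_shape['vertices'].shape, mano_para['pose'].shape)  # (778, 3), (45,)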