Exemple #1
0
    def train(self, epochs):
        """Train ``self.model`` for ``epochs`` passes over ``self.dataloader``.

        For every batch the MANO layer turns ``sample['poses']`` and
        ``sample['shapes']`` into hand joints.  The flattened joints are
        concatenated with ``sample['uv_root']`` and ``sample['scale']``, that
        vector is duplicated along the feature axis, and the result is fed to
        the model.  The target is the flattened ``sample['xyz']``.

        Every ``self.save_rate`` global steps the accumulated loss is
        averaged, the state is saved through ``self.save_state()``, and the
        average is written to ``self.writer`` and printed.

        Args:
            epochs: Number of passes over the dataloader.
        """
        self.running_loss = 0.0

        # The layer configuration never changes, so build it once instead of
        # re-creating it for every batch.
        mano_layer = ManoLayer(mano_root='mano/models',
                               use_pca=False,
                               ncomps=48,
                               flat_hand_mean=False)
        mano_layer.to(self.device)

        for epoch in range(epochs):
            # The start index belongs to enumerate(); the original passed it
            # to tqdm, where it was taken as the progress-bar description.
            for i, sample in enumerate(tqdm(self.dataloader), 0):
                poses = sample['poses']
                shapes = sample['shapes']

                # Forward pass through MANO layer
                _, hand_joints = mano_layer(poses, shapes)

                uv_root = sample['uv_root']
                scale = sample['scale']

                # Use the real batch dimension so a smaller final batch is
                # flattened per sample instead of failing or mixing rows.
                current_batch = hand_joints.shape[0]
                hand_joints = hand_joints.reshape([current_batch, -1])
                x = torch.cat((hand_joints, uv_root, scale), 1)
                # The feature vector is duplicated along dim 1 before being
                # fed to the model.
                x = torch.cat((x, x), 1)
                y = sample['xyz'].reshape([current_batch, -1])

                x = x.to(self.device)
                y = y.to(self.device)

                self.optimizer.zero_grad()

                y_ = self.model(x)

                loss = self.criterion(y_, y)

                loss.backward()
                self.optimizer.step()

                self.running_loss += loss.item()
                self.g_step += 1

                if self.g_step % self.save_rate == self.save_rate - 1:
                    # Average exactly once; the value that is saved, logged
                    # and printed below is this mean.
                    self.running_loss /= self.save_rate
                    self.save_state()
                    self.writer.add_scalar('training loss',
                                           self.running_loss,
                                           self.g_step)
                    print(self.running_loss, self.g_step)
                    self.running_loss = 0.0
class ManoDatasetC(Dataset):
    """Dataset of synthetic MANO hands sampled from a kernel density estimate.

    The density is fitted on the first 51 values of the selected training
    annotations.  Each item is a fresh random draw pushed through the MANO
    layer, so ``idx`` does not select a stored record.
    """

    def __init__(self, base_path, transform, train_indices):
        """Fit the density model and build the MANO layer.

        Args:
            base_path: Directory that contains ``training_mano.json``.
            transform: Callable applied to the ``{'p': ..., 's': ...}`` dict
                before it is handed to the MANO layer.
            train_indices: Index selecting the annotation rows used to fit
                the density estimate.
        """
        self.transform = transform

        annotations = json_load(
            os.path.join(base_path, '%s_mano.json' % 'training'))
        # Drop the singleton axis, keep the leading 51 values per record and
        # restrict to the requested rows.
        fit_data = np.array(annotations).squeeze(1)[..., :51][train_indices]

        self.kde = KernelDensity(bandwidth=0.15, kernel='gaussian')
        self.kde.fit(fit_data)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.mano_layer = ManoLayer(mano_root='mano/models',
                                    use_pca=False,
                                    ncomps=45,
                                    flat_hand_mean=False)
        self.mano_layer.to(self.device)

    def __len__(self):
        """Return the fixed nominal dataset size."""
        return 32560

    def __getitem__(self, idx):
        """Draw one random sample and run it through the MANO layer.

        Returns:
            Dict with the squeezed ``hand_joints``, ``hand_verts``, ``poses``
            and ``shapes`` tensors.
        """
        drawn = self.kde.sample()

        # Values from position 48 on overwrite the first three shape entries;
        # the remaining seven stay at 1.
        betas = np.ones([1, 10])
        betas[..., :3] = drawn[..., 48:]

        inputs = self.transform({'p': drawn[..., :48], 's': betas})

        verts, joints = self.mano_layer(inputs['p'], inputs['s'])
        joints = joints.reshape([joints.shape[0], 63])

        return {
            'hand_joints': torch.squeeze(joints),
            'hand_verts': torch.squeeze(verts),
            'poses': torch.squeeze(inputs['p']),
            'shapes': torch.squeeze(inputs['s'])
        }
Exemple #3
0
def xyz_from_mano(poses, shapes):
    """Run NumPy pose and shape arrays through a freshly built MANO layer.

    Both arrays are converted to float tensors on CUDA when it is available,
    otherwise on the CPU.

    Args:
        poses: NumPy array of pose parameters.
        shapes: NumPy array of shape parameters.

    Returns:
        Tuple ``(hand_verts, hand_joints)`` as produced by the MANO layer.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # PCA is disabled for this layer; the component count is passed along
    # unchanged.
    layer = ManoLayer(mano_root='mano/models',
                      use_pca=False,
                      ncomps=45,
                      flat_hand_mean=False)
    layer.to(device)

    pose_tensor = torch.from_numpy(poses).float().to(device)
    shape_tensor = torch.from_numpy(shapes).float().to(device)

    verts, joints = layer(pose_tensor, shape_tensor)
    return verts, joints
Exemple #4
0
class TrainerP0(object):
    """Train a model that regresses MANO parameters from ``sample['uv']``.

    The trainer owns the optimiser, the TensorBoard writer and a single
    checkpoint file at ``results/<build_id>.pt`` that is reloaded on
    construction when it exists.
    """

    def __init__(self, batch_size, dataloader, model, build_id):
        """Set up device, logging, optimiser and checkpoint state.

        Args:
            batch_size: Stored on the instance; not used by the code in this
                class.
            dataloader: Iterable of dict batches with ``'uv'`` and ``'mano'``
                entries.
            model: Module to train; it is moved to the selected device.
            build_id: Name used for the checkpoint file and the log directory.
        """
        self.batch_size = batch_size
        self.dataloader = dataloader
        self.model = model

        self.save_path = f'results/{build_id}.pt'

        self.writer = SummaryWriter(f'results/{build_id}')
        input_example = next(iter(dataloader))['uv']

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # The model already lives on self.device, so the example used to
        # trace the graph has to be moved there as well.
        self.writer.add_graph(self.model, input_example.to(self.device))
        self.writer.close()

        self.criterion = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=0.0001,
                                   momentum=0.9)

        self.load_state()

        ncomps = 45
        # Only needed for the (currently disabled) vertex/joint losses.
        self.mano_layer = ManoLayer(mano_root='mano/models',
                                    use_pca=False,
                                    ncomps=ncomps,
                                    flat_hand_mean=False)

        self.mano_layer.to(self.device)
        self.mano_layer.eval()

    def train(self, epochs, save_rate):
        """Run ``epochs`` passes over the dataloader, checkpointing each epoch.

        After every epoch the mean per-batch loss is computed once, saved
        with the checkpoint, written to the summary writer and printed.

        Args:
            epochs: Number of passes over ``self.dataloader``.
            save_rate: Accepted for interface compatibility; the state is
                saved once per epoch regardless of this value.
        """
        self.running_loss = 0.0

        for epoch in range(epochs):
            num_batches = 0
            for i, sample in enumerate(self.dataloader, 0):
                uv = sample['uv'].to(self.device)
                outputs_gt = sample['mano'].to(self.device)

                self.optimizer.zero_grad()

                # NOTE(review): 3.12 is a fixed output scale (close to pi);
                # confirm the intended range of the targets.
                outputs = self.model(uv) * 3.12

                # Only the parameter-space loss is used.  Losses on MANO
                # vertices/joints (splitting the outputs at column 48 and
                # running self.mano_layer) are currently disabled.
                loss = self.criterion(outputs, outputs_gt)

                loss.backward()
                self.optimizer.step()

                self.running_loss += loss.item()
                self.g_step += 1
                num_batches += 1

            # Normalise exactly once, by the number of batches that were
            # actually accumulated (guarding against an empty dataloader).
            self.running_loss /= max(num_batches, 1)
            self.save_state()
            self.writer.add_scalar('training loss', self.running_loss,
                                   self.g_step)
            print(self.running_loss, self.g_step)
            self.running_loss = 0.0

    def save_state(self):
        """Write step counter, model, optimiser and running loss to disk.

        ``self.running_loss`` is stored as-is; ``train`` has already turned
        it into a mean before calling this method.
        """
        torch.save(
            {
                'g_step': self.g_step,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'running_loss': self.running_loss,
            }, self.save_path)

        print(
            f'Model saved at step {self.g_step} with running loss {self.running_loss}.'
        )

    def load_state(self):
        """Restore training state from ``self.save_path`` if it exists.

        When no checkpoint is found the step counter and running loss start
        from zero and the model keeps its current parameters.
        """
        if os.path.exists(self.save_path):
            # map_location lets a checkpoint written on a GPU machine be
            # resumed on a CPU-only one (and vice versa).
            checkpoint = torch.load(self.save_path, map_location=self.device)
            self.g_step = checkpoint['g_step'] + 1
            self.running_loss = checkpoint['running_loss']

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(
                f'Model loaded. g_step: {self.g_step}; running_loss: {self.running_loss}'
            )
        else:
            print(
                f'File "{self.save_path}" does not exist. Initializing parameters from scratch.'
            )
            self.g_step = 0
            self.running_loss = 0.0
Exemple #5
0
                          CONFIG_M0['data_version'],
                          transform=ToTensor())

# Pull one shuffled sample from the dataset.  No batch_size is passed, so the
# DataLoader default applies; the reshape to [1, -1] further down relies on a
# single-sample batch.
dataloader = DataLoader(dataset, shuffle=True)
sample = next(iter(dataloader))
random_pose = sample['poses']
random_shape = sample['shapes']

print(random_pose.shape)
# Initialize MANO layer (PCA disabled)
mano_layer = ManoLayer(mano_root='mano/models',
                       use_pca=False,
                       ncomps=48,
                       flat_hand_mean=False)

mano_layer.to(device)
# Forward pass through MANO layer
# NOTE(review): the inputs are not moved to `device` here — confirm the
# dataset already yields tensors on that device.
hand_verts, hand_joints = mano_layer(random_pose, random_shape)
# Optional visual check of the generated hand (disabled):
# demo.display_hand({
#     'verts': hand_verts,
#     'joints': hand_joints
# },
#     mano_faces=mano_layer.th_faces)

# Build the network input: flattened joints followed by uv_root and scale,
# then the whole vector duplicated along the feature axis.
uv_root = sample['uv_root']
scale = sample['scale']
hand_joints = hand_joints.reshape([1, -1])
x = torch.cat((hand_joints, uv_root, scale), 1)
x = torch.cat((x, x), 1)

# NETWORK WAY
Exemple #6
0
class pcl2mano():
    """Fit MANO hand-model parameters to a part-labelled point cloud.

    Every point carries one of six part labels (0..5, named palm, thumb,
    index, middle, ring and pinky in the code).  The MANO vertices get the
    same labels from a per-face segmentation file, and the fit matches model
    vertices and target points part by part in two stages: a coarse stage
    over global translation/rotation and a fine stage over all parameters.
    """

    def __init__(self,
                 device,
                 mano_root='mano/models',
                 ncomps=45,
                 seg_file='./face2label_sealed.npy',
                 visualize_interm_result=False):
        """Build the MANO layer and the vertex-to-part lookup tables.

        Args:
            device: Torch device used for the layer and all optimised tensors.
            mano_root: Directory with the MANO model files.
            ncomps: Number of PCA pose components handed to the MANO layer.
            seg_file: ``.npy`` file whose column 0 holds one part label per
                mesh face.
            visualize_interm_result: When true, intermediate fits are shown
                through ``demo.display_mosh``.
        """
        self.ncomps = ncomps
        self.device = device
        self.mano_root = mano_root
        self.mano_layer = ManoLayer(mano_root=self.mano_root,
                                    use_pca=True,
                                    ncomps=self.ncomps,
                                    flat_hand_mean=False)
        self.mano_layer = self.mano_layer.to(self.device)

        segmentation = np.load(seg_file)
        self.faces = self.mano_layer.th_faces.detach().cpu()
        # Assign each MANO vertex the label of a face it belongs to.  Faces
        # are visited in order, so the last face touching a vertex wins.
        self.vertex_label = np.zeros((self.faces.max() + 1), dtype=np.uint8)

        for i in range(0, self.faces.shape[0]):
            for corner in range(3):
                self.vertex_label[self.faces[i, corner]] = segmentation[i, 0]

        # label_vertex[p] lists the vertex ids that belong to part p.
        self.label_vertex = [
            np.where(self.vertex_label == part)[0] for part in range(6)
        ]

        self.visualize_interm_result = visualize_interm_result

        # Per-fit state, filled in by fit_mano_2_pcl.
        self.target_points = []
        self.target_points_tree = []
        self.no_label = []

    def find_nearest_neighbour_index_from_hand(self, target_points_tree, hands,
                                               no_label):
        """For every hand vertex find the nearest target point of its part.

        Args:
            target_points_tree: Six KD-trees, one per part, built over the
                target points.
            hands: Batch of hand vertices; must be a 3-D tensor.
            no_label: Unused here; kept for interface compatibility.

        Returns:
            A list with one entry per hand; each entry is a list of six index
            arrays (one per part) as returned by the trees' ``query``.

        Raises:
            ValueError: If ``hands`` is not a 3-D batch.
        """
        if len(hands.shape) != 3:
            # Raise instead of calling exit(), which would terminate the
            # whole process with a success status.
            raise ValueError(
                "find_nearest_neighbour_index_from_hand expects a batch of "
                "hands with 3 dimensions, got shape %s" %
                (tuple(hands.shape), ))

        closest_indices = []
        for i in range(0, hands.shape[0]):
            hand = hands[i, :, :].detach().cpu().numpy()
            closest_indices.append([
                target_points_tree[part].query(
                    hand[self.label_vertex[part], :], 1)[1]
                for part in range(6)
            ])

        return closest_indices

    def index2points_from_hand(self, indices, closest_points):
        """Write the matched target points into ``closest_points`` in place.

        Args:
            indices: Output of ``find_nearest_neighbour_index_from_hand``.
            closest_points: Tensor indexed as ``[hand, vertex, coordinate]``
                that receives the target points.

        Returns:
            The same ``closest_points`` tensor.
        """
        for batch in range(0, len(indices)):  # iterate over batch dimension
            for label in range(0, 6):  # iterate over hand parts
                part_vertices = self.label_vertex[label][:]
                matched = indices[batch][label][:, 0]
                closest_points[batch, part_vertices, :] = \
                    self.target_points[label][matched, :]

        return closest_points

    def find_nearest_neighbour_to_hand(self, target_points, n, hands,
                                       no_label):
        """For every target point find the nearest hand vertex of its part.

        Parts flagged in ``no_label`` are skipped.

        Args:
            target_points: Six NumPy arrays of target points, one per part.
            n: Number of target points per hand after skipping empty parts;
                the concatenations below require the per-hand totals to match.
            hands: Batch of hand vertices as a 3-D tensor.
            no_label: Six booleans; true means the part has no target points.

        Returns:
            Tuple ``(indices, targets)``: per hand the matched vertex ids and
            the target points in the same order.
        """
        closest_batch_hand_indices = np.zeros((0, n), dtype=np.int32)
        target_points_batch_rearrange = np.zeros((0, n, 3), dtype=np.float32)
        for i in range(0, hands.shape[0]):  # iterate over batch dimension
            hand = hands[i, :, :].detach().cpu().numpy()
            closest_hand_indices = np.zeros((0), dtype=np.int32)
            target_points_rearrange = np.zeros((0, 3), dtype=np.float32)
            for hand_part in range(0, 6):
                # If there is no predicted label of that part, skip.
                if no_label[hand_part]:
                    continue
                hand_part_vertex = hand[self.label_vertex[hand_part], :]
                hand_part_tree = KDTree(hand_part_vertex)
                _, closest_part_index = hand_part_tree.query(
                    target_points[hand_part], 1)
                # Map indices local to the part back to global vertex ids.
                closest_hand_indices = np.concatenate(
                    (closest_hand_indices,
                     self.label_vertex[hand_part][closest_part_index[:, 0]]),
                    0)
                target_points_rearrange = np.concatenate(
                    (target_points_rearrange, target_points[hand_part]))
            target_points_rearrange = np.expand_dims(target_points_rearrange,
                                                     axis=0)

            target_points_batch_rearrange = np.concatenate(
                (target_points_batch_rearrange, target_points_rearrange), 0)
            closest_hand_indices = np.expand_dims(closest_hand_indices, axis=0)
            closest_batch_hand_indices = np.concatenate(
                (closest_batch_hand_indices, closest_hand_indices), 0)

        return closest_batch_hand_indices, target_points_batch_rearrange

    def index2points_to_hand(self, indices, hands):
        """Gather, per hand, the vertices selected by ``indices``.

        Args:
            indices: 2-D integer array; row ``b`` holds the vertex ids to take
                from hand ``b``.
            hands: Batch of hand vertices as a 3-D tensor.

        Returns:
            Tensor with one row of gathered vertices per row of ``indices``.
        """
        gathered = [
            hands[batch, indices[batch, :], :]
            for batch in range(0, indices.shape[0])
        ]
        if not gathered:
            # Keep the empty-batch result of the original implementation.
            return torch.zeros(0, indices.shape[1], 3).float().to(self.device)
        # One stack instead of growing the result with repeated torch.cat.
        return torch.stack(gathered, dim=0)

    def mask_no_label_from_hand(self, hand_verts, no_label):
        """Zero, in place, the vertices of every part flagged in ``no_label``.

        Returns:
            The same ``hand_verts`` tensor.
        """
        for part, missing in enumerate(no_label):
            if missing:
                hand_verts[:, self.label_vertex[part], :] = 0.
        return hand_verts

    def mask_no_label_to_hand(self, hand_verts, no_label):
        """Same masking as ``mask_no_label_from_hand``; kept as a separate name."""
        for part, missing in enumerate(no_label):
            if missing:
                hand_verts[:, self.label_vertex[part], :] = 0.
        return hand_verts

    def fit_mano_2_pcl(
            self,
            samples,  # samples is a (N_v x 3) array, unit is Millimeter!
            labels,  # labels is a (N_v x 1) array
            seeds=8,
            coarse_iter=50,
            fine_iter=50,
            stop_loss=5.0,
            verbose=0):
        """Fit MANO parameters to a labelled point cloud.

        Several randomly initialised seeds are optimised in parallel and the
        seed with the smallest error in the last fine iteration is returned.

        Args:
            samples: ``(N_v, 3)`` array of points in millimetres.
            labels: ``(N_v, 1)`` array of part labels in 0..5.
            seeds: Number of parallel initialisations.
            coarse_iter: Maximum outer iterations of the global stage, which
                optimises translation and rotation only.
            fine_iter: Maximum outer iterations of the fine stage, which
                optimises translation, rotation, pose and shape.
            stop_loss: The fine stage stops once the best per-seed error
                falls below this value.
            verbose: Values >= 1 print the loss after each outer iteration.

        Returns:
            Tuple ``(hand_shape, mano_para)``.  ``hand_shape`` holds
            ``'vertices'``, ``'joints'`` and ``'faces'``; ``mano_para`` holds
            ``'rot'``, ``'pose'``, ``'shape'`` and ``'trans'`` of the
            selected seed.
        """
        # Classify samples according to labels.
        parts = [samples[(np.where(labels == part))[0], :] for part in range(6)]

        # Rebuild the flags on every call.  Appending to the list left over
        # from an earlier fit would make later code read stale entries.
        self.no_label = [part.shape[0] == 0 for part in parts]

        # Add temp point (0,0,0) to part with no sample.
        # The loss from these points will be masked out later when calculating loss
        parts = [
            np.zeros([1, 3]) if missing else part
            for part, missing in zip(parts, self.no_label)
        ]

        self.target_points_np = parts
        self.target_points = [
            torch.from_numpy(part).float().to(self.device) for part in parts
        ]
        self.target_points_tree = [KDTree(part) for part in parts]

        # Model para initialization:
        shape = torch.zeros(seeds, 10).float().to(self.device)
        shape.requires_grad_()
        rot = torch.zeros(seeds, 3).float().to(self.device)
        rot.requires_grad_()
        pose = (0.1 * torch.randn(seeds, self.ncomps)).float().to(self.device)
        pose.requires_grad_()
        trans = torch.from_numpy(samples.mean(0) /
                                 1000.0)  # trans should be in meter
        trans = trans.unsqueeze(0).repeat(seeds, 1).float().to(self.device)
        trans.requires_grad_()

        hand_verts, hand_joints = self.mano_layer(torch.cat((rot, pose), 1),
                                                  shape, trans)

        if self.visualize_interm_result:
            demo.display_mosh(
                torch.from_numpy(samples).float().unsqueeze(0).expand(
                    seeds, -1, -1),
                np.zeros((0, 4), dtype=np.int32),
                {'verts': hand_verts.detach().cpu(),
                 'joints': hand_joints.detach().cpu()},
                mano_faces=self.mano_layer.th_faces.detach().cpu(),
                alpha=0.3)

        # Global optimization
        criteria_loss = nn.MSELoss().to(self.device)
        previous_loss = 1e8
        optimizer = torch.optim.Adam([trans, rot], lr=1e-2)
        print('...Optimizing global transformation...')
        for i in range(0, coarse_iter):
            hand_verts, hand_joints = self.mano_layer(
                torch.cat((rot, pose), 1), shape, trans)
            # Find closest label points (correspondences stay fixed during
            # the inner gradient steps):
            closest_indices = self.find_nearest_neighbour_index_from_hand(
                self.target_points_tree, hand_verts, self.no_label)
            closest_points = self.index2points_from_hand(
                closest_indices, torch.zeros_like(hand_verts))

            for _ in range(0, 20):
                hand_verts, hand_joints = self.mano_layer(
                    torch.cat((rot, pose), 1), shape, trans)
                hand_verts = self.mask_no_label_from_hand(
                    hand_verts, self.no_label)
                loss = criteria_loss(hand_verts, closest_points)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss = criteria_loss(hand_verts, closest_points)
            if verbose >= 1:
                print(i, loss.data)
            # Stop once an outer iteration improves the loss by less than 0.1.
            if previous_loss - loss.data < 1e-1:
                break
            previous_loss = loss.data.detach()

        if self.visualize_interm_result:
            demo.display_mosh(
                torch.from_numpy(samples).float().unsqueeze(0).expand(
                    seeds, -1, -1),
                np.zeros((0, 4), dtype=np.int32),
                {'verts': hand_verts.detach().cpu(),
                 'joints': hand_joints.detach().cpu()},
                mano_faces=self.mano_layer.th_faces.detach().cpu(),
                alpha=0.3)

        # Local optimization
        optimizer = torch.optim.Adam([trans, rot, pose, shape], lr=1e-2)
        # Falls back to the first seed when fine_iter is 0 and no per-seed
        # error is ever computed.
        min_index = 0
        w_pose = 100.0  # weight of the pose regularizer
        print('...Optimizing hand pose shape and global transformation...')
        for i in range(0, fine_iter):
            hand_verts, hand_joints = self.mano_layer(
                torch.cat((rot, pose), 1), shape, trans)
            # Find closest label points:
            closest_batch_hand_indices, target_points_batch_rearrange = \
                self.find_nearest_neighbour_to_hand(
                    self.target_points_np, samples.shape[0], hand_verts,
                    self.no_label)
            target_points_batch_rearrange = torch.from_numpy(
                target_points_batch_rearrange).float().to(self.device)

            for _ in range(0, 20):
                hand_verts, hand_joints = self.mano_layer(
                    torch.cat((rot, pose), 1), shape, trans)
                hands_batch_rearrange = self.index2points_to_hand(
                    closest_batch_hand_indices, hand_verts)

                loss = criteria_loss(
                    hands_batch_rearrange, target_points_batch_rearrange
                ) + w_pose * (pose * pose).mean()  # pose regularizer
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Find smallest loss in the #seeds seeds:
            residual = hands_batch_rearrange - target_points_batch_rearrange
            per_seed_error = (residual * residual).mean(2).mean(1)
            min_index = torch.argmin(per_seed_error).detach().cpu().numpy()
            min_error = per_seed_error[min_index]
            if verbose >= 1:
                print(i, min_error.data)

            if self.visualize_interm_result and i % 40 == 0:
                # Link target point k with its matched hand vertex, which
                # sits n_targets rows later in visual_points; only the first
                # 200 links are drawn.
                n_targets = target_points_batch_rearrange.shape[1]
                tmp_arange = np.expand_dims(np.arange(n_targets), axis=1)
                link = np.concatenate((tmp_arange, tmp_arange + n_targets), 1)
                link = link[0:200, :]
                visual_points = torch.cat(
                    (target_points_batch_rearrange[min_index, :, :],
                     hands_batch_rearrange[min_index, :, :]),
                    dim=0)
                visual_points = visual_points.unsqueeze(0).expand(
                    seeds, -1, -1)

                demo.display_mosh(
                    visual_points.detach().cpu(),
                    link,
                    {'verts': hand_verts.detach().cpu(),
                     'joints': hand_joints.detach().cpu()},
                    mano_faces=self.mano_layer.th_faces.detach().cpu(),
                    alpha=0.3)

            if min_error < stop_loss:
                break

        # Re-evaluate the model with the final parameters.
        hand_verts, hand_joints = self.mano_layer(torch.cat((rot, pose), 1),
                                                  shape, trans)

        hand_shape = {
            'vertices': hand_verts.detach().cpu().numpy()[min_index, :, :],
            'joints': hand_joints.detach().cpu().numpy()[min_index, :, :],
            'faces': self.mano_layer.th_faces.detach().cpu()
        }

        mano_para = {
            'rot': rot.detach().cpu().numpy()[min_index, :],
            'pose': pose.detach().cpu().numpy()[min_index, :],
            'shape': shape.detach().cpu().numpy()[min_index, :],
            'trans': trans.detach().cpu().numpy()[min_index, :]
        }

        return hand_shape, mano_para