def get_vertices(self, shape_param, exp_param, batch_size):
        """
        Build mesh vertices from identity (shape) and expression coefficients.

        Per sample: shape_mu + shape_pc @ shape_param + exp_pc @ exp_param,
        then reshaped into (x, y, z) triplets.

        :param shape_param: [batch, n_shape_para, 1] or [batch, n_shape_para]
        :param exp_param:   [batch, n_exp_para, 1] or [batch, n_exp_para]
        :param batch_size: number of samples; used by the shape checks and
            the final reshape
        :return: vertices: [batch, n_vertices, 3]
        :raises ValueError: if a parameter tensor is neither rank 2 nor rank 3
        """

        def _to_column(param, name, n_para):
            # Check `param` against [batch, n_para] or [batch, n_para, 1] and
            # always return the rank-3 column form used by the einsums below.
            rank = tf.shape(param).shape[0]
            if rank == 2:
                tf.debugging.assert_shapes(
                    [(param, (batch_size, n_para))],
                    message='{name} shape wrong, dim != ({batch}, {dim})'.
                    format(name=name, batch=batch_size, dim=n_para))
                return tf.expand_dims(param, 2)
            if rank == 3:
                tf.debugging.assert_shapes(
                    [(param, (batch_size, n_para, 1))],
                    message='{name} shape wrong, dim != ({batch}, {dim}, 1)'.
                    format(name=name, batch=batch_size, dim=n_para))
                return param
            raise ValueError(
                '{name} shape wrong, dim != ({batch}, {dim}, 1) or ({batch}, {dim})'
                .format(name=name, batch=batch_size, dim=n_para))

        assert is_tf_expression(shape_param) and is_tf_expression(exp_param)

        shape_col = _to_column(shape_param, 'shape_param', self.n_shape_para)
        exp_col = _to_column(exp_param, 'exp_param', self.n_exp_para)

        # Mean shape broadcast over the batch axis (presumably shape_mu is
        # [3 * n_vertices, 1] -- TODO confirm).
        mean = tf.expand_dims(self.shape_mu, 0)
        # Each einsum: [rows, n_para] x [batch, n_para, 1] -> [batch, rows, 1]
        identity_offset = tf.einsum('ij,kjs->kis', self.shape_pc, shape_col)
        expression_offset = tf.einsum('ij,kjs->kis', self.exp_pc, exp_col)
        flat = mean + identity_offset + expression_offset

        return tf.reshape(flat, (batch_size, self.n_vertices, 3))
    def _get_texture(self, tex_param, batch_size):
        """
        Generate per-vertex texture from texture coefficients.

        Per sample: tex_mu + tex_pc @ tex_param, reshaped to RGB triplets.

        :param tex_param: [batch, n_tex_para, 1] or [batch, n_tex_para]
        :param batch_size: number of samples; used by the shape checks and
            the final reshape
        :return: tex: [batch, n_vertices, 3]
        :raises ValueError: if tex_param is neither rank 2 nor rank 3
        """
        assert is_tf_expression(tex_param)

        # tf.shape(...) is a 1-D tensor whose static length is the rank.
        tp_shape = tf.shape(tex_param)
        if tp_shape.shape[0] == 2:
            tf.debugging.assert_shapes(
                [(tex_param, (batch_size, self.n_tex_para))],
                message='tex_param shape wrong, dim != ({batch}, {dim})'.
                format(batch=batch_size, dim=self.n_tex_para))
            # Promote to the column form expected by the einsum below.
            tex_param = tf.expand_dims(tex_param, 2)
        elif tp_shape.shape[0] == 3:
            tf.debugging.assert_shapes(
                [(tex_param, (batch_size, self.n_tex_para, 1))],
                message='tex_param shape wrong, dim != ({batch}, {dim}, 1)'.
                format(batch=batch_size, dim=self.n_tex_para))
        else:
            raise ValueError(
                'tex_param shape wrong, dim != ({batch}, {dim}, 1) or ({batch}, {dim})'
                .format(batch=batch_size, dim=self.n_tex_para))

        # [rows, n_tex_para] x [batch, n_tex_para, 1] -> [batch, rows, 1]
        tex = tf.expand_dims(self.tex_mu, 0) + tf.einsum(
            'ij,kjs->kis', self.tex_pc, tex_param)
        tex = tf.reshape(tex, (batch_size, self.n_vertices, 3))
        return tex
    def _get_illum(self, illum_param, batch_size):
        """
        Decode the illumination parameters.

        The last axis of illum_param is read as:
            [0:3] diagonal of amb
            [3:6] diagonal of d
            [6]   thetal, light elevation angle
            [7]   phil, light azimuth angle
            [8]   ks
            [9]   v
        The light direction is l = (cos(thetal) sin(phil), sin(thetal),
        cos(thetal) cos(phil)), which has unit length.

        :param illum_param: [batch, 1, n_illum_para] or [batch, n_illum_para]
            (indices 0..9 are read)
        :param batch_size: number of samples in the batch
        :return: tuple (h, ks, v, amb, d, l)
            h:   [batch, 3, 1] normalised l + (0, 0, 1)
            ks:  [batch, 1]
            v:   [batch, 1]
            amb: [batch, 3, 3] diagonal
            d:   [batch, 3, 3] diagonal
            l:   [batch, 3, 1] unit light direction
        :raises ValueError: if illum_param is neither rank 2 nor rank 3
        """
        assert is_tf_expression(illum_param)

        ip_shape = tf.shape(illum_param)
        if ip_shape.shape[0] == 2:
            tf.debugging.assert_shapes(
                [(illum_param, (batch_size, self.n_illum_para))],
                message='illum_param shape wrong, dim != ({batch}, {dim})'.
                format(batch=batch_size, dim=self.n_illum_para))
            illum_param = tf.expand_dims(illum_param, 1)
        elif ip_shape.shape[0] == 3:
            tf.debugging.assert_shapes(
                [(illum_param, (batch_size, 1, self.n_illum_para))],
                message='illum_param shape wrong, dim != ({batch}, 1, {dim})'.
                format(batch=batch_size, dim=self.n_illum_para))
        else:
            raise ValueError(
                'illum_param shape wrong, dim != ({batch}, 1, {dim}) or ({batch}, {dim})'
                .format(batch=batch_size, dim=self.n_illum_para))

        # Each of these is [batch, 1].
        thetal = illum_param[:, :, 6]
        phil = illum_param[:, :, 7]
        ks = illum_param[:, :, 8]
        v = illum_param[:, :, 9]

        amb = tf.linalg.diag(illum_param[:, 0, 0:3])
        d = tf.linalg.diag(illum_param[:, 0, 3:6])

        # l: [batch, 3]
        l = tf.concat([
            tf.math.cos(thetal) * tf.math.sin(phil),
            tf.math.sin(thetal),
            tf.math.cos(thetal) * tf.math.cos(phil)
        ], axis=1)

        # Half vector between l and the fixed direction (0, 0, 1) (presumably
        # the viewing direction). The constant takes the input dtype so that
        # float64 parameters do not fail with a dtype mismatch.
        fixed_dir = tf.constant([[0, 0, 1]], dtype=l.dtype)
        h = l + fixed_dir
        h = h / tf.sqrt(tf.reduce_sum(tf.square(h), axis=1, keepdims=True))

        return (tf.reshape(h, (batch_size, -1, 1)), ks, v, amb, d,
                tf.reshape(l, (batch_size, -1, 1)))
    def _get_color(self, color_param, batch_size):
        """
        Decode the color-transform parameters.

        The last axis of color_param is read as:
            [0:3] diagonal of g
            [3:6] o (tiled over all vertices)
            [6]   c
        The data used to train the model has a constant 1 at position 6
        (color_param[:, 0, 6] == 1), so the model may be configured with only
        6 color parameters. In that case the constant c == 1 is appended here.

        :param color_param: [batch, 1, n_color_para] or [batch, n_color_para],
            with n_color_para either 6 or 7
        :param batch_size: number of samples in the batch
        :returns:
             o: [batch, n_vertices, 3]
             M: constant matrix [3, 3] (each row holds the rounded RGB luma
                weights 0.3, 0.59, 0.11), in the dtype of color_param
             g: diagonal matrix [batch, 3, 3]
             c: [batch, 1]
        :raises ValueError: if color_param is neither rank 2 nor rank 3
        """
        assert is_tf_expression(color_param)

        cp_shape = tf.shape(color_param)
        if cp_shape.shape[0] == 2:
            tf.debugging.assert_shapes(
                [(color_param, (batch_size, self.n_color_para))],
                message='color_param shape wrong, dim != ({batch}, {dim})'.
                format(batch=batch_size, dim=self.n_color_para))
            color_param = tf.expand_dims(color_param, 1)
        elif cp_shape.shape[0] == 3:
            tf.debugging.assert_shapes(
                [(color_param, (batch_size, 1, self.n_color_para))],
                message='color_param shape wrong, dim != ({batch}, 1, {dim})'.
                format(batch=batch_size, dim=self.n_color_para))
        else:
            raise ValueError(
                'color_param shape wrong, dim != ({batch}, 1, {dim}) or ({batch}, {dim})'
                .format(batch=batch_size, dim=self.n_color_para))

        # With only 6 parameters the constant c == 1 was stripped from the
        # data; restore it, otherwise indexing position 6 below fails.
        if color_param.shape[-1] == 6:
            color_param = tf.concat(
                [color_param, tf.ones_like(color_param[:, :, :1])], axis=2)

        # c shape: (batch, 1)
        c = color_param[:, :, 6]
        # Use the input dtype so downstream einsums with c do not mismatch.
        M = tf.constant(
            [[0.3, 0.59, 0.11], [0.3, 0.59, 0.11], [0.3, 0.59, 0.11]],
            shape=(3, 3),
            dtype=color_param.dtype)
        g = tf.linalg.diag(color_param[:, 0, 0:3])
        o = tf.reshape(color_param[:, 0, 3:6], (batch_size, 1, 3))
        # o matrix of shape (batch, n_vertices, 3)
        o = tf.tile(o, [1, self.n_vertices, 1])

        tf.debugging.assert_shapes([(o, (batch_size, self.n_vertices, 3))])

        return o, M, g, c
    def get_vertex_colors(self, tex_param, color_param, illum_param,
                          vertex_norm, batch_size):
        """
        Generate shaded per-vertex colors for rendering.

        Written for a single vertex with column-vector colors:
            L     = amb @ tex + d @ ((n . l) * tex) + ks * (n . h) ** v
            CT    = g @ (c * I + (1 - c) * M)
            color = CT @ L + o
        Colors are stored row-wise here ([n_vertices, 3] per sample), so the
        last step is evaluated as L @ CT^T.

        If color_param or illum_param is None, the texture from _get_texture
        is returned without lighting or color transform.

        :param tex_param: [batch, n_tex_para, 1] or [batch, n_tex_para]
        :param color_param: [batch, 1, n_color_para] or [batch, n_color_para],
            or None
        :param illum_param: [batch, 1, n_illum_para] or [batch, n_illum_para],
            or None
        :param vertex_norm: vertex normals [batch, n_vertices, 3]
        :param batch_size: number of samples in the batch
        :return: vertex colors [batch, n_vertices, 3]
        """
        assert is_tf_expression(tex_param)
        assert is_tf_expression(vertex_norm)

        tex = self._get_texture(tex_param=tex_param, batch_size=batch_size)

        if color_param is None or illum_param is None:
            return tex

        assert is_tf_expression(color_param)
        assert is_tf_expression(illum_param)

        # o: [batch, n_vertices, 3]; M: [3, 3]; g: [batch, 3, 3] diagonal;
        # c: [batch, 1]
        o, M, g, c = self._get_color(color_param=color_param,
                                     batch_size=batch_size)

        # h, l: [batch, 3, 1]; ks, v: [batch, 1];
        # amb, d: [batch, 3, 3] diagonal
        h, ks, v, amb, d, l = self._get_illum(illum_param=illum_param,
                                              batch_size=batch_size)

        # n_l, n_h: [batch, n_vertices, 1]. They broadcast over the 3 color
        # channels below, so no explicit tiling is needed.
        n_l = tf.einsum('ijk,iks->ijs', vertex_norm, l)
        n_h = tf.einsum('ijk,iks->ijs', vertex_norm, h)

        # ks, v -> [batch, 1, 1] so they broadcast over vertices and channels.
        ks = tf.expand_dims(ks, axis=2)
        v = tf.expand_dims(v, axis=1)

        # NOTE(review): n_l and n_h are not clamped at 0. Vertices facing away
        # get a negative diffuse term, and pow(negative, v) is NaN for
        # non-integer v. Kept as-is to preserve existing behaviour.
        # L: [batch, n_vertices, 3]
        L = (tf.einsum('ijk,iks->ijs', tex, amb) +
             tf.einsum('ijk,iks->ijs', n_l * tex, d) +
             ks * tf.math.pow(n_h, v))

        # CT = g @ (c * I + (1 - c) * M). Every row of M holds the same RGB
        # luma weights, so M maps a color to its gray level in all channels.
        # c therefore blends the original color (c = 1) with grayscale (c = 0).
        # This must be a matrix product: an element-wise product with the
        # diagonal g would zero every channel-mixing entry of M.
        c_expanded = tf.linalg.diag(tf.tile(c, (1, 3)))       # c * I
        nc_expanded = tf.linalg.diag(tf.tile(1 - c, (1, 3)))  # (1 - c) * I
        blend = c_expanded + tf.einsum('ijk,ks->ijs', nc_expanded, M)
        # CT of shape (batch, 3, 3)
        CT = tf.einsum('ijk,iks->ijs', g, blend)

        # Colors are rows, so apply CT to each vertex as L @ CT^T.
        # vertex_colors: (batch, n_ver, 3)
        vertex_colors = tf.einsum('ijk,isk->ijs', L, CT) + o

        tf.debugging.assert_shapes(
            [(vertex_colors, (batch_size, self.n_vertices, 3))],
            message='vertex_colors shape wrong, dim != ({batch}, {n_vert}, 3)'.
            format(batch=batch_size, n_vert=self.n_vertices))

        return vertex_colors
def render_batch(pose_param, shape_param, exp_param, tex_param, color_param,
                 illum_param, frame_width: int, frame_height: int,
                 tf_bfm: TfMorphableModel, batch_size: int):
    """
    Render a batch of faces with the morphable model and the dirt rasteriser.

    :param pose_param: [batch, n_pose_para] or [batch, 1, n_pose_para]; the
        last axis is passed to affine_transform as [0:3] angles (radians),
        [3:6] t3d and [6:] scaling
    :param shape_param: [batch, n_shape_para, 1] or [batch, n_shape_para]
    :param exp_param:   [batch, n_exp_para, 1] or [batch, n_exp_para]
    :param tex_param: [batch, n_tex_para, 1] or [batch, n_tex_para]
    :param color_param: [batch, 1, n_color_para] or [batch, n_color_para]
    :param illum_param: [batch, 1, n_illum_para] or [batch, n_illum_para]
    :param frame_width: rendered image width
    :param frame_height: rendered image height
    :param tf_bfm: basel face model
    :param batch_size: batch size
    :return: images, [batch, frame_height, frame_width, 3]. Vertex colors are
        clipped to [0, 1] before rasterisation and the result is scaled by 255.
    :raises ValueError: if pose_param is neither rank 2 nor rank 3
    """
    assert is_tf_expression(pose_param)

    n_pose = tf_bfm.get_num_pose_param()
    pose_shape = tf.shape(pose_param)
    if pose_shape.shape[0] == 2:
        tf.debugging.assert_shapes(
            [(pose_param, (batch_size, n_pose))],
            message='pose_param shape wrong, dim != ({batch}, {dim})'.format(
                batch=batch_size, dim=n_pose))
        pose_param = tf.expand_dims(pose_param, 1)
    elif pose_shape.shape[0] == 3:
        tf.debugging.assert_shapes(
            [(pose_param, (batch_size, 1, n_pose))],
            message='pose_param shape wrong, dim != ({batch}, 1, {dim})'.
            format(batch=batch_size, dim=n_pose))
    else:
        raise ValueError(
            'pose_param shape wrong, dim != ({batch}, 1, {dim}) or ({batch}, {dim})'
            .format(batch=batch_size, dim=n_pose))

    vertices = tf_bfm.get_vertices(shape_param=shape_param,
                                   exp_param=exp_param,
                                   batch_size=batch_size)

    # The same triangle list for every sample, tiled once and reused for the
    # normal computation and for the rasteriser.
    batched_triangles = tf.tile(tf.expand_dims(tf_bfm.triangles, axis=0),
                                (batch_size, 1, 1))

    vertex_norm = tfg.geometry.representation.mesh.normals.vertex_normals(
        vertices=vertices, indices=batched_triangles, clockwise=True)

    # Normals are negated before shading; presumably to match the winding
    # convention of tf_bfm.triangles -- TODO confirm.
    colors = tf_bfm.get_vertex_colors(tex_param=tex_param,
                                      color_param=color_param,
                                      illum_param=illum_param,
                                      vertex_norm=-vertex_norm,
                                      batch_size=batch_size)

    colors = tf.clip_by_value(colors / 255., 0., 1.)

    # NOTE(review): t3d keeps the singleton middle axis ([batch, 1, 3]) while
    # scaling and angles drop it; presumably so the translation broadcasts
    # over vertices inside affine_transform -- confirm before changing.
    transformed_vertices = affine_transform(vertices=vertices,
                                            scaling=pose_param[:, 0, 6:],
                                            angles_rad=pose_param[:, 0, 0:3],
                                            t3d=pose_param[:, 0:, 3:6])

    # Map x in [0, frame_width] and y in [0, frame_height] to [-1, 1]. z is
    # negated and divided by the largest |z| over the whole batch.
    x = transformed_vertices[:, :, 0] * 2 / frame_width - 1
    y = transformed_vertices[:, :, 1] * 2 / frame_height - 1
    z = -transformed_vertices[:, :, 2] / tf.reduce_max(
        tf.abs(transformed_vertices[:, :, 2]))

    # Homogeneous coordinates: [batch, n_vertices, 4]
    clip_vertices = tf.stack([x, y, z, tf.ones_like(x)], axis=2)

    # Render the G-buffer
    image = dirt.rasterise_batch(
        vertices=clip_vertices,
        faces=batched_triangles,
        vertex_colors=colors,
        background=tf.zeros([batch_size, frame_height, frame_width, 3]),
        width=frame_width,
        height=frame_height,
        channels=3)

    return image * 255