Example 1
def get_scaled_orthographic_projection(scale, trans, quat, transpose=False):
    """
    Generate the rotation and translation of a scaled orthographic projection
    for the given batch of scales, translations and quaternion rotations.
    Outputs are created on the same device as `scale`.

    :param scale: A [B] tensor with the scale values for the batch
    :param trans: A [B, 2] tensor with tx and ty values for the batch
    :param quat: A [B, 4] tensor with quaternion values for the batch
    :param transpose: If True, transpose the rotation matrix before scaling
    
    :return: A tuple (rotation, translation)
        rotation - A [B, 3, 3] tensor for rotation
        translation - A [B, 3] tensor for translation
    """

    device = scale.device
    # Append a constant depth of 5 as the z component of the translation.
    translation = torch.cat(
        (trans,
         torch.ones([trans.size(0), 1], dtype=torch.float, device=device) * 5),
        dim=1)

    scale_matrix = torch.zeros((scale.size(0), 3, 3), device=device)
    scale_matrix[:, 0, 0] = scale
    scale_matrix[:, 1, 1] = scale
    scale_matrix[:, 2, 2] = scale

    rotation = quaternion_to_matrix(quat)
    if transpose:
        rotation = rotation.permute(0, 2, 1)
    rotation = torch.matmul(scale_matrix, rotation)

    return rotation, translation
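A minimal usage sketch for the function above (a sketch, assuming torch and pytorch3d.transforms are importable next to it; random_quaternions is only used to fabricate inputs):

import torch
from pytorch3d.transforms import random_quaternions

B = 8
scale = torch.rand(B) + 0.5          # [B]
trans = torch.randn(B, 2)            # [B, 2] -> (tx, ty)
quat = random_quaternions(B)         # [B, 4] -> (w, x, y, z)

rotation, translation = get_scaled_orthographic_projection(scale, trans, quat)
print(rotation.shape, translation.shape)   # torch.Size([8, 3, 3]) torch.Size([8, 3])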
Example 2
def axis_angle_loss(axis_angle, quat):
    # Predicted rotation, transposed so the product below is R_pred^T @ R_label.
    output = tr.axis_angle_to_matrix(axis_angle)
    output = torch.transpose(output, 1, 2)
    # Reorder the label quaternion from (x, y, z, w) to (w, x, y, z).
    q2c = torch.zeros(quat.shape, device=quat.device)
    q2c[:, 0] = quat[:, 3]
    q2c[:, 1:] = quat[:, 0:3]
    label = tr.quaternion_to_matrix(q2c)
    # Geodesic angle from the trace: theta = arccos((trace(R) - 1) / 2),
    # clamped for numerical safety before acos.
    trace = torch.diagonal(
        torch.matmul(output, label), dim1=-2, dim2=-1).sum(-1)
    diff = torch.acos(torch.clamp((trace - 1) / 2, -1.0, 1.0))
    return diff.mean()
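A quick check of the geodesic-angle loss above (a sketch, assuming pytorch3d.transforms is imported as tr and axis_angle_loss is in scope): comparing a rotation with itself gives roughly zero, while comparing a 90-degree rotation about z with the identity gives roughly pi/2.

import math
import torch
import pytorch3d.transforms as tr

# 90-degree rotation about z as an axis-angle vector, plus the same rotation and
# the identity as (x, y, z, w) quaternions, matching the label layout above.
axis_angle = torch.tensor([[0.0, 0.0, math.pi / 2]])
quat_same_xyzw = torch.tensor([[0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)]])
quat_identity_xyzw = torch.tensor([[0.0, 0.0, 0.0, 1.0]])

print(axis_angle_loss(axis_angle, quat_same_xyzw))      # ~0
print(axis_angle_loss(axis_angle, quat_identity_xyzw))  # ~pi/2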
Example 3
    def forward(self, source, target, fsource, ftarget):
        '''
            Input: point cloud (B, N, 3; unused) and feature (B, F, N)
            Output: rotation (B, 3, 3) and translation (B, 3)
        '''
        x = torch.cat([fsource, ftarget], dim=1)
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 2*self.embed_dim) # (B, 2F)

        for i in range(self.n_mlp_layers):
            x = F.relu(self.bn[i](self.fc[i](x)))

        # Rotation (B, 4) -> (B, 3, 3)
        rot = self.rotation(x)
        rot = rot / torch.linalg.norm(rot, dim=1, keepdim=True)
        rot = transforms.quaternion_to_matrix(rot)

        # Translation (B, 3)
        tr = self.translation(x)
        return rot, tr
Example 4
def _make_node_transform(node: Dict[str, Any]) -> Transform3d:
    """
    Convert a transform from the json data into a PyTorch3D
    Transform3d format.
    """
    array = node.get("matrix")
    if array is not None:  # Stored in column-major order
        M = np.array(array, dtype=np.float32).reshape(4, 4, order="F")
        return Transform3d(matrix=torch.from_numpy(M))

    out = Transform3d()

    # Given some subset of (scale / rotation / translation), apply them in that
    # order to take points into world space.
    # See https://github.com/KhronosGroup/glTF/issues/743 .

    array = node.get("scale", None)
    if array is not None:
        scale_vector = torch.FloatTensor(array)
        out = out.scale(scale_vector[None])

    # Rotation quaternion (x, y, z, w) where w is the scalar
    array = node.get("rotation", None)
    if array is not None:
        x, y, z, w = array
        # We negate w. This is equivalent to inverting the rotation.
        # This is needed as quaternion_to_matrix makes a matrix which
        # operates on column vectors, whereas Transform3d wants a
        # matrix which operates on row vectors.
        rotation_quaternion = torch.FloatTensor([-w, x, y, z])
        rotation_matrix = quaternion_to_matrix(rotation_quaternion)
        out = out.rotate(R=rotation_matrix)

    array = node.get("translation", None)
    if array is not None:
        translation_vector = torch.FloatTensor(array)
        out = out.translate(x=translation_vector[None])

    return out
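A quick illustration of the quaternion handling above (a sketch, assuming pytorch3d.transforms is available): negating the scalar part w yields the inverse rotation, i.e. the transposed matrix, which is why the code builds [-w, x, y, z].

import torch
from pytorch3d.transforms import quaternion_to_matrix, random_quaternions

q_wxyz = random_quaternions(1)                            # (w, x, y, z), as PyTorch3D expects
q_neg_w = q_wxyz * torch.tensor([[-1.0, 1.0, 1.0, 1.0]])  # negate only the scalar part

R = quaternion_to_matrix(q_wxyz)
R_neg = quaternion_to_matrix(q_neg_w)
# Negating w inverts the rotation, which for a rotation matrix is its transpose.
print(torch.allclose(R_neg, R.transpose(1, 2), atol=1e-6))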
Example 5
def build_rotation(rotation, format="matrix") -> Rotation:
    """ Convert roation (with format) into Rotation.
    format: matrix, ortho6d, quat, euler
    """
    # 1. CONVERT SPECIFIED FORMAT TO MATRIX FIRST
    if format == "matrix":
        matrix = rotation
    elif format == "ortho6d":
        matrix = compute_rotation_matrix_from_ortho6d(rotation)
    elif format == "euler":
        matrix = euler_angles_to_matrix(rotation, convention="XYZ")
    elif format == "quat":
        matrix = quaternion_to_matrix(rotation)
    else:
        raise TypeError(f"unknown rotation format: {format}")

    # 2. BUILD ROTATION
    return Rotation(
        ortho6d=rotation if format == "ortho6d" else
        compute_ortho6d_from_rotation_matrix(matrix),
        quat=rotation if format == "quat" else matrix_to_quaternion(matrix),
        matrix=rotation if format == "matrix" else matrix,
        euler=rotation if format == "euler" else matrix_to_euler_angles(
            matrix, convention="XYZ"))
Example 6
    def forward(self, inputs: List[Dict[Hashable, th.Tensor]]
                ) -> List[List[Tuple[th.Tensor, th.Tensor]]]:
        inputs0 = inputs.copy()

        # NOTE(ycho): custom transform + crop-aware collation.
        inputs = [self.transform(x) for x in inputs]
        inputs = collate_cropped_img(inputs)
        if (Schema.CROPPED_IMAGE not in inputs or
                len(inputs[Schema.CROPPED_IMAGE]) <= 0):
            return None

        # if we use the crop-aware collation, (batch_idx, instance_idx)
        indices = inputs[Schema.INDEX]
        dim, quat = self.model(inputs[Schema.CROPPED_IMAGE].to(
            self.device))

        # NOTE(ycho): len(inputs0) is used as the batch size.
        batch_size = len(inputs0)
        outputs = [[] for _ in range(batch_size)]

        for i, (ii, s, q) in enumerate(zip(indices, dim, quat)):
            batch_index, instance_index = ii
            P = inputs0[batch_index][Schema.PROJECTION].reshape(4, 4)
            R = quaternion_to_matrix(q[None])[0]

            #R2 = inputs0[batch_index][
            #    Schema.ORIENTATION][instance_index].reshape(
            #    3, 3)
            #q2 = matrix_to_quaternion(R2[None])[0]
            #s2 = inputs0[batch_index][Schema.SCALE][instance_index]

            # Fix BOX_2D convention.
            box_i, box_j, box_h, box_w = inputs[Schema.BOX_2D][i]
            box_2d = th.as_tensor([box_i, box_j, box_i + box_h, box_j + box_w])
            box_2d = 2.0 * (box_2d - 0.5)

            # Solve translation
            translation, _ = self.solve_translation({
                # inputs from dataset
                Schema.PROJECTION: P,
                Schema.BOX_2D: box_2d,
                # inputs from network
                Schema.ORIENTATION: R,
                Schema.QUATERNION: q,
                Schema.SCALE: s
                # inputs from dataset (ground-truth)
                # Schema.ORIENTATION: R2,
                # Schema.QUATERNION: q2,
                # Schema.SCALE: s2
            })

            translation = th.as_tensor(translation, device=R.device)
            if translation[-1] > 0:
                translation *= -1.0
            #print('tr-solve')
            #print(translation)
            #print('tr-gt')
            #print(inputs0[batch_index][Schema.TRANSLATION][instance_index])

            # Convert to box-points
            box_out = self.box_points({
                Schema.ORIENTATION: R.to(self.device),
                Schema.TRANSLATION: translation.to(self.device),
                Schema.SCALE: s.to(self.device),
                Schema.PROJECTION: P.to(self.device),
                Schema.INSTANCE_NUM: 1
            })

            entry = (
                box_out[Schema.KEYPOINT_2D][0, ..., :2].detach().cpu().numpy(),
                box_out[Schema.KEYPOINT_3D][0].detach().cpu().numpy()
            )
            outputs[batch_index].append(entry)
        return outputs
Example 7
    def __call__(
            self, inputs: Dict[Hashable, th.Tensor]
    ) -> Tuple[np.ndarray, np.ndarray]:
        proj_matrix = inputs[Schema.PROJECTION]
        # NOTE(ycho): BOX_2D = (i0, j0, i1, j1) in normalized coords (-1,1).
        box_2d = inputs[Schema.BOX_2D]

        # NOTE(ycho): outputs from network
        dimension = inputs[Schema.SCALE]
        if Schema.ORIENTATION in inputs:
            R = inputs[Schema.ORIENTATION].detach().cpu().numpy()
        elif Schema.QUATERNION in inputs:
            quaternion = inputs[Schema.QUATERNION]
            R = (quaternion_to_matrix(
                th.as_tensor(quaternion)).detach().cpu().numpy())
        else:
            raise KeyError('Orientation information Not Found!')

        vertices = (self.points.cpu() * dimension.cpu()).detach().numpy()

        if True:
            # Reduce the number of permutations through geometric reasoning.
            fovs = 2.0 * np.arctan(1.0 / proj_matrix[[0, 1], [0, 1]])
            with warnings.catch_warnings():
                warnings.filterwarnings(action='ignore',
                                        category=LinAlgWarning)
                warnings.filterwarnings(action='ignore',
                                        category=OptimizeWarning)
                warnings.filterwarnings(action='ignore',
                                        category=np.VisibleDeprecationWarning)
                warnings.filterwarnings(action='ignore',
                                        category=RuntimeWarning)
                perms = compute_feasible_permutations(vertices @ R.T, fovs,
                                                      self.debug_chull)
            perms = np.asarray(list(perms), dtype=np.int32)
            constraints = vertices[perms, :]
        else:
            constraints = list(itertools.permutations(vertices, 4))

        # Initialize current best candidates.
        best_loc = None
        best_error = np.inf
        best_X = None

        # Loop through each possible constraint, hold on to the best guess
        K = proj_matrix.detach().cpu().numpy()

        # Create design matrices Ax=b for SVD.
        # K_ax is the axes of K repeated for each corresponding spatial axis.
        K_ax = K[(0, 1, 0, 1), :3]
        # TODO(ycho): use integer permutations directly and index into the array,
        # instead of creating (large) redundant copies
        for X in constraints:
            A = np.einsum('n,a->na', box_2d, K[2, :3]) - K_ax
            b = -np.einsum('na,ab,nb->n', A, R, X)

            # Solve here with least squares since overparameterized.
            # NOTE(ycho): `error` here indicates algebraic error;
            # it's generally preferable to use geometric error.
            loc, error, rank, s = np.linalg.lstsq(A, b, rcond=None)

            # Evaluate solution ...
            if self.recompute_error:
                # NOTE(ycho): evaluate error based on match with box.
                # FIXME(ycho): This probably results in much more expensive
                # evaluation.
                args = {
                    Schema.ORIENTATION: th.as_tensor(R).detach().cpu(),
                    Schema.TRANSLATION: th.as_tensor(loc).detach().cpu(),
                    Schema.SCALE: th.as_tensor(dimension).detach().cpu(),
                    Schema.PROJECTION: th.as_tensor(K).detach().cpu(),
                    Schema.INSTANCE_NUM: 1
                }
                out_points = self.box_points(args)[Schema.KEYPOINT_2D][..., :2]
                out_points = th.flip(out_points, dims=(-1, ))  # XY->IJ
                out_points = 2.0 * (out_points - 0.5)  # (0,1) -> (-1, +1)
                pmin = out_points.min(dim=-2).values.reshape(-1)
                pmax = out_points.max(dim=-2).values.reshape(-1)
                out_box_2d = th.cat([pmin, pmax])
                error2 = th.norm(box_2d - out_box_2d.to(box_2d.device))
                error = error2

            # Update estimate with better alternative.
            if (error < best_error):
                best_loc = loc
                best_error = error
                best_X = X

        return best_loc, best_X
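A minimal, self-contained sketch of the least-squares step above, with made-up intrinsics, rotation, box and vertices purely to show the shapes: each box edge contributes one row of the overdetermined system A x = b, and np.linalg.lstsq recovers a candidate translation.

import numpy as np

# Hypothetical inputs: an NDC-style projection matrix, identity rotation,
# a normalized 2D box (i0, j0, i1, j1) in (-1, 1) and one object-frame vertex per edge.
K = np.array([[1.2, 0.0, 0.0, 0.0],
              [0.0, 1.6, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])
R = np.eye(3)
box_2d = np.array([-0.2, -0.3, 0.2, 0.3])
X = 0.1 * np.random.randn(4, 3)

# Same design-matrix construction as above: solve A t = b for the translation t.
K_ax = K[(0, 1, 0, 1), :3]
A = np.einsum('n,a->na', box_2d, K[2, :3]) - K_ax
b = -np.einsum('na,ab,nb->n', A, R, X)
loc, error, rank, s = np.linalg.lstsq(A, b, rcond=None)
print(loc)   # candidate translation, shape (3,)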
Example 8
def main():
    # data
    transform = Compose([
        CropObject(CropObject.Settings()),
        Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE, )))
    ])
    _, test_loader = get_loaders(DatasetSettings(),
                                 th.device('cpu'),
                                 1,
                                 transform=transform,
                                 collate_fn=collate_cropped_img)
    # model
    device = th.device('cuda')
    model = load_model()
    model = model.to(device)
    model.eval()

    # translation solver?
    solve_translation = SolveTranslation()

    box_points = BoxPoints2D(th.device('cpu'), Schema.KEYPOINT_2D)
    draw_bbox = DrawBoundingBoxFromKeypoints(
        DrawBoundingBoxFromKeypoints.Settings())

    # eval
    for data in test_loader:
        # Skip occasional batches without any images.
        if Schema.CROPPED_IMAGE not in data:
            continue

        with th.no_grad():
            # run inference
            crop_img = data[Schema.CROPPED_IMAGE].view(-1, 3, 224, 224)
            dim, quat = model(crop_img.to(device))
            dim2, quat2 = data[Schema.SCALE], data[Schema.QUATERNION]
            logging.debug('D {} {}'.format(dim, dim2))
            logging.debug('Q {} {}'.format(quat, quat2))
            # trans = data[Schema.TRANSLATION]

            if False:
                dim = dim2
                quat = quat2
                R = quaternion_to_matrix(quat)

            R = quaternion_to_matrix(quat)

            input_image = data[Schema.IMAGE].detach().cpu()
            proj_matrix = (data[Schema.PROJECTION].detach().cpu().reshape(
                -1, 4, 4))

            # Solve translations.
            translations = []
            for i in range(len(proj_matrix)):
                box_i, box_j, box_h, box_w = data[Schema.BOX_2D][i]
                box_2d = th.as_tensor(
                    [box_i, box_j, box_i + box_h, box_j + box_w])
                box_2d = 2.0 * (box_2d - 0.5)
                args = {
                    # inputs from dataset
                    Schema.PROJECTION: proj_matrix[i],
                    Schema.BOX_2D: box_2d,
                    # inputs from network
                    Schema.ORIENTATION: R[i],
                    Schema.QUATERNION: quat[i],
                    Schema.SCALE: dim[i]
                }
                # Solve translation
                translation, _ = solve_translation(args)
                translations.append(translation)
            translations = th.as_tensor(translations, dtype=th.float32)

            if True:
                print('num instances = {}'.format(len(translations)))
                pred_data = {
                    Schema.IMAGE: data[Schema.IMAGE][0],
                    Schema.ORIENTATION: R.cpu(),
                    Schema.TRANSLATION: translations,
                    Schema.SCALE: dim.cpu(),
                    Schema.PROJECTION: proj_matrix[0],
                    Schema.INSTANCE_NUM: len(proj_matrix),
                }
                pred_data = box_points(pred_data)
                pred_data = draw_bbox(pred_data)
                image_with_box = pred_data['img_w_bbox']
            else:
                dimensions = dim.detach().cpu()
                quaternion = quat.detach().cpu()
                translations = translations.detach().cpu()

                #print(input_image.shape)
                #print(data[Schema.BOX_2D].shape)
                #print(proj_matrix.shape)
                #print(translations.shape)
                #print(dimensions.shape)
                #print(quaternion.shape)

                # draw box
                image_with_box = plot_regressed_3d_bbox(
                    input_image,
                    # keypoints_2d,
                    # data[Schema.BOX_2D],
                    data[Schema.KEYPOINT_2D],
                    proj_matrix,
                    dimensions,
                    quaternion,
                    translations)

            plt.clf()
            plt.imshow(image_with_box.permute(1, 2, 0))
            plt.pause(0.1)
Example 9
    def run_optimization(
        self,
        silhouettes: torch.tensor,
        R: torch.tensor,
        T: torch.tensor,
        writer=None,
        camera_settings=None,
        step: int = 0,
    ):
        """
        Function:
            Runs a batched optimization procedure that aims to minimize 3 reconstruction losses:
                -Silhouette IoU Loss: between input silhouettes and re-projected mesh
                -Mesh Edge consistency
                -Mesh Normal smoothing
            Mini-batching:
                If the number of silhouettes is greater than the allowed batch size, then a random set of images/poses is sampled for supervision at each step
        Returns:
            -Reconstruction losses: 3 reconstruction losses measured during optimization
            -Timing:
                -Iterations / second
                -Total time elapsed in seconds
        """

        if len(R.shape) == 4:
            R = R.squeeze(1)
            T = T.squeeze(1)

        tf_smaller = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.params.img_size),
            transforms.ToTensor(),
        ])

        images_gt = torch.stack([
            tf_smaller(s.cpu()).to(self.device) for s in silhouettes
        ]).squeeze(1)

        if images_gt.max() > 1.0:
            images_gt = images_gt / 255.0

        loop = tqdm_notebook(range(self.params.mesh_steps))

        start_time = time.time()
        for i in loop:
            batch_indices = (random.choices(list(range(images_gt.shape[0])),
                                            k=self.params.mesh_batch_size) if
                             images_gt.shape[0] > self.params.mesh_batch_size
                             else list(range(images_gt.shape[0])))
            batch_silhouettes = images_gt[batch_indices]

            batch_R, batch_T = R[batch_indices], T[batch_indices]
            # right-multiply Twv by the initial camera pose to account for the coordinate-system shift from EVIMO to PyTorch3D
            if self.params.is_real_data:
                init_R = quaternion_to_matrix(self.init_camera_R)
                batch_R = _broadcast_bmm(batch_R, init_R)
                batch_T = (
                    _broadcast_bmm(batch_T[:, None, :], init_R) +
                    self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
                focal_length = (torch.tensor([
                    camera_settings[0, 0], camera_settings[1, 1]
                ])[None]).expand(batch_R.shape[0], 2)
                principle_point = (torch.tensor([
                    camera_settings[0, 2], camera_settings[1, 2]
                ])[None]).expand(batch_R.shape[0], 2)
                # FIXME: in this PyTorch3D version, image_size in RasterizationSettings is (W, H),
                # while in PerspectiveCameras it is (H, W). If a future PyTorch3D release changes
                # this convention, update the settings here accordingly.
                batch_cameras = PerspectiveCameras(
                    device=self.device,
                    R=batch_R,
                    T=batch_T,
                    focal_length=focal_length,
                    principal_point=principle_point,
                    image_size=((self.params.img_size[1],
                                 self.params.img_size[0]), ))
            else:
                batch_cameras = PerspectiveCameras(device=self.device,
                                                   R=batch_R,
                                                   T=batch_T)

            mesh, laplacian_loss, flatten_loss = self.forward(
                self.params.mesh_batch_size)

            images_pred = self.renderer(mesh,
                                        device=self.device,
                                        cameras=batch_cameras)[..., -1]

            iou_loss = IOULoss().forward(batch_silhouettes, images_pred)

            loss = (iou_loss * self.params.lambda_iou +
                    laplacian_loss * self.params.lambda_laplacian +
                    flatten_loss * self.params.lambda_flatten)

            loop.set_description("Optimizing (loss %.4f)" % loss.data)

            self.losses["iou"].append(iou_loss * self.params.lambda_iou)
            self.losses["laplacian"].append(laplacian_loss *
                                            self.params.lambda_laplacian)
            self.losses["flatten"].append(flatten_loss *
                                          self.params.lambda_flatten)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if i % (self.params.mesh_show_step /
                    2) == 0 and self.params.mesh_log:
                logging.info(
                    f'Iteration: {i} IOU Loss: {iou_loss.item()} Flatten Loss: {flatten_loss.item()} Laplacian Loss: {laplacian_loss.item()}'
                )

            if i % self.params.mesh_show_step == 0 and self.params.im_show:
                # Write images
                image = images_pred.detach().cpu().numpy()[0]

                if writer:
                    writer.append_data((255 * image).astype(np.uint8))
                plt.imshow(images_pred.detach().cpu().numpy()[0])
                plt.show()
                plt.imshow(batch_silhouettes.detach().cpu().numpy()[0])
                plt.show()
                plot_pointcloud(mesh[0], 'Mesh deformed')
                logging.info(
                    f'Pose of init camera: {self.init_camera_R.detach().cpu().numpy()}, {self.init_camera_t.detach().cpu().numpy()}'
                )

        # Set the final optimized mesh as an internal variable
        self.final_mesh = mesh[0].clone()
        results = dict(
            silhouette_loss=self.losses["iou"]
            [-1].detach().cpu().numpy().tolist(),
            laplacian_loss=self.losses["laplacian"]
            [-1].detach().cpu().numpy().tolist(),
            flatten_loss=self.losses["flatten"]
            [-1].detach().cpu().numpy().tolist(),
            iterations_per_second=self.params.mesh_steps /
            (time.time() - start_time),
            total_time_s=time.time() - start_time,
        )
        if self.is_real_data:
            self.init_pose_R = self.init_camera_R.detach().cpu().numpy()
            self.init_pose_t = self.init_camera_t.detach().cpu().numpy()

        torch.cuda.empty_cache()

        return results
Example 10
    def render_final_mesh(self,
                          poses,
                          mode: str,
                          out_size: list,
                          camera_settings=None) -> dict:
        """Renders the final mesh obtained through optimization
            Supports two modes:
                -predict: renders both silhouettes and flat shaded images
                -train: only renders silhouettes
            Returns:
                -dict of renders {'silhouettes': tensor, 'images': tensor}
        """
        R, T = poses
        if len(R.shape) == 4:
            R = R.squeeze(1)
            T = T.squeeze(1)

        sil_renderer = silhouette_renderer(out_size, self.device)
        image_renderer = flat_renderer(out_size, self.device)

        # Create a silhouette projection of the mesh across all views
        all_silhouettes = []
        all_images = []
        for i in range(0, R.shape[0]):
            batch_R, batch_T = R[[i]], T[[i]]
            if self.params.is_real_data:
                init_R = quaternion_to_matrix(self.init_camera_R)
                batch_R = _broadcast_bmm(batch_R, init_R)
                batch_T = (
                    _broadcast_bmm(batch_T[:, None, :], init_R) +
                    self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
                focal_length = torch.tensor(
                    [camera_settings[0, 0], camera_settings[1, 1]])[None]
                principle_point = torch.tensor(
                    [camera_settings[0, 2], camera_settings[1, 2]])[None]
                t_cameras = PerspectiveCameras(
                    device=self.device,
                    R=batch_R,
                    T=batch_T,
                    focal_length=focal_length,
                    principal_point=principle_point,
                    image_size=((self.params.img_size[1],
                                 self.params.img_size[0]), ))
            else:
                t_cameras = PerspectiveCameras(device=self.device,
                                               R=batch_R,
                                               T=batch_T)
            all_silhouettes.append(
                sil_renderer(self._final_mesh,
                             device=self.device,
                             cameras=t_cameras).detach().cpu()[..., -1])

            if mode == "predict":
                all_images.append(
                    torch.clamp(
                        image_renderer(self._final_mesh,
                                       device=self.device,
                                       cameras=t_cameras),
                        0,
                        1,
                    ).detach().cpu()[..., :3])
            torch.cuda.empty_cache()
        renders = dict(
            silhouettes=torch.cat(all_silhouettes).unsqueeze(-1).permute(
                0, 3, 1, 2),
            images=torch.cat(all_images) if all_images else [],
        )

        return renders
Example 11
        points = points.bmm(r)
        points = points + t[:, None, :]

    return points


if __name__ == "__main__":
    rand_pt = torch.randn(4, 1000, 3)

    rand_qt = random_qt(4, 0.5, 3)

    # qt -> rt
    q = rand_qt[:, :4]
    t = rand_qt[:, 4:, None]

    R = pt3d_T.quaternion_to_matrix(q)
    Rinv = pt3d_T.quaternion_to_matrix(pt3d_T.quaternion_invert(q))

    Rti = torch.cat((Rinv, t), dim=2)
    Rt = torch.cat((R, t), dim=2)

    rot_qt = transform_points_qt(rand_pt, rand_qt)
    rot_Rt = transform_points_Rt(rand_pt, Rt)
    rot_Rti = transform_points_Rt(rand_pt, Rti)

    qt_Rt = (rot_qt - rot_Rt).norm(dim=2, p=2).mean()
    qt_Rti = (rot_qt - rot_Rti).norm(dim=2, p=2).mean()
    Rt_Rti = (rot_Rti - rot_Rt).norm(dim=2, p=2).mean()

    print(f"|| points ||:    {rand_pt.norm(p=2,dim=2).mean()}")
    print(f"Diff Rt and qt:  {qt_Rt:.4e}")
Example 12
    def forward(self,
                seq,
                msa=None,
                mask=None,
                msa_mask=None,
                extra_msa=None,
                extra_msa_mask=None,
                seq_index=None,
                seq_embed=None,
                msa_embed=None,
                templates_feats=None,
                templates_mask=None,
                templates_angles=None,
                embedds=None,
                recyclables=None,
                return_trunk=False,
                return_confidence=False,
                return_recyclables=False,
                return_aux_logits=False):
        assert not (
            self.disable_token_embed and not exists(seq_embed)
        ), 'sequence embedding must be supplied if one has disabled token embedding'
        assert not (
            self.disable_token_embed and not exists(msa_embed)
        ), 'msa embedding must be supplied if one has disabled token embedding'

        # if MSA is not passed in, just use the sequence itself

        if not exists(msa):
            msa = rearrange(seq, 'b n -> b () n')
            msa_mask = rearrange(mask, 'b n -> b () n')

        # assert on sequence length

        assert msa.shape[-1] == seq.shape[
            -1], 'sequence length of MSA and primary sequence must be the same'

        # variables

        b, n, device = *seq.shape[:2], seq.device
        n_range = torch.arange(n, device=device)

        # unpack (AA_code, atom_pos)

        if isinstance(seq, (list, tuple)):
            seq, seq_pos = seq

        # embed main sequence

        x = self.token_emb(seq)

        if exists(seq_embed):
            x += seq_embed

        # mlm for MSAs

        if self.training and exists(msa):
            original_msa = msa
            msa_mask = default(msa_mask, lambda: torch.ones_like(msa).bool())

            noised_msa, replaced_msa_mask = self.mlm.noise(msa, msa_mask)
            msa = noised_msa

        # embed multiple sequence alignment (msa)

        if exists(msa):
            m = self.token_emb(msa)

            if exists(msa_embed):
                m = m + msa_embed

            # add single representation to msa representation

            m = m + rearrange(x, 'b n d -> b () n d')

            # get msa_mask to all ones if none was passed
            msa_mask = default(msa_mask, lambda: torch.ones_like(msa).bool())

        elif exists(embedds):
            m = self.embedd_project(embedds)

            # get msa_mask to all ones if none was passed
            msa_mask = default(
                msa_mask, lambda: torch.ones_like(embedds[..., -1]).bool())
        else:
            raise ValueError('either MSA or embeds must be given')

        # derive pairwise representation

        x_left, x_right = self.to_pairwise_repr(x).chunk(2, dim=-1)
        x = rearrange(x_left, 'b i d -> b i () d') + rearrange(
            x_right, 'b j d-> b () j d')  # create pair-wise residue embeds
        x_mask = rearrange(mask, 'b i -> b i ()') * rearrange(
            mask, 'b j -> b () j') if exists(mask) else None

        # add relative positional embedding

        seq_index = default(seq_index, lambda: torch.arange(n, device=device))
        seq_rel_dist = rearrange(seq_index, 'i -> () i ()') - rearrange(
            seq_index, 'j -> () () j')
        seq_rel_dist = seq_rel_dist.clamp(
            -self.max_rel_dist, self.max_rel_dist) + self.max_rel_dist
        rel_pos_emb = self.pos_emb(seq_rel_dist)

        x = x + rel_pos_emb

        # add recyclables, if present

        if exists(recyclables):
            m[:, 0] = m[:, 0] + self.recycling_msa_norm(
                recyclables.single_msa_repr_row)
            x = x + self.recycling_pairwise_norm(recyclables.pairwise_repr)

            distances = torch.cdist(recyclables.coords,
                                    recyclables.coords,
                                    p=2)
            boundaries = torch.linspace(2,
                                        20,
                                        steps=self.recycling_distance_buckets,
                                        device=device)
            discretized_distances = torch.bucketize(distances, boundaries[:-1])
            distance_embed = self.recycling_distance_embed(
                discretized_distances)

            x = x + distance_embed

        # embed templates, if present

        if exists(templates_feats):
            _, num_templates, *_ = templates_feats.shape

            # embed template

            t = self.to_template_embed(templates_feats)
            t_mask_crossed = rearrange(templates_mask,
                                       'b t i -> b t i ()') * rearrange(
                                           templates_mask, 'b t j -> b t () j')

            t = rearrange(t, 'b t ... -> (b t) ...')
            t_mask_crossed = rearrange(t_mask_crossed, 'b t ... -> (b t) ...')

            for _ in range(self.templates_embed_layers):
                t = self.template_pairwise_embedder(t, mask=t_mask_crossed)

            t = rearrange(t, '(b t) ... -> b t ...', t=num_templates)
            t_mask_crossed = rearrange(t_mask_crossed,
                                       '(b t) ... -> b t ...',
                                       t=num_templates)

            # template pos emb

            x_point = rearrange(x, 'b i j d -> (b i j) () d')
            t_point = rearrange(t, 'b t i j d -> (b i j) t d')
            x_mask_point = rearrange(x_mask, 'b i j -> (b i j) ()')
            t_mask_point = rearrange(t_mask_crossed, 'b t i j -> (b i j) t')

            template_pooled = self.template_pointwise_attn(
                x_point,
                context=t_point,
                mask=x_mask_point,
                context_mask=t_mask_point)

            template_pooled_mask = rearrange(
                t_mask_point.sum(dim=-1) > 0, 'b -> b () ()')
            template_pooled = template_pooled * template_pooled_mask

            template_pooled = rearrange(template_pooled,
                                        '(b i j) () d -> b i j d',
                                        i=n,
                                        j=n)
            x = x + template_pooled

        # add template angle features to MSAs by passing through MLP and then concat

        if exists(templates_angles):
            t_angle_feats = self.template_angle_mlp(templates_angles)
            m = torch.cat((m, t_angle_feats), dim=1)
            msa_mask = torch.cat((msa_mask, templates_mask), dim=1)

        # embed extra msa, if present

        if exists(extra_msa):
            extra_m = self.token_emb(extra_msa)
            extra_msa_mask = default(extra_msa_mask,
                                     torch.ones_like(extra_msa).bool())

            x, extra_m = self.extra_msa_evoformer(x,
                                                  extra_m,
                                                  mask=x_mask,
                                                  msa_mask=extra_msa_mask)

        # trunk

        x, m = self.net(x, m, mask=x_mask, msa_mask=msa_mask)

        # ready output container

        ret = ReturnValues()

        # calculate theta and phi before symmetrization

        if self.predict_angles:
            ret.theta_logits = self.to_prob_theta(x)
            ret.phi_logits = self.to_prob_phi(x)

        # embeds to distogram

        trunk_embeds = (x +
                        rearrange(x, 'b i j d -> b j i d')) * 0.5  # symmetrize
        distance_pred = self.to_distogram_logits(trunk_embeds)
        ret.distance = distance_pred

        # calculate mlm loss, if training

        msa_mlm_loss = None
        if self.training and exists(msa):
            num_msa = original_msa.shape[1]
            msa_mlm_loss = self.mlm(m[:, :num_msa], original_msa,
                                    replaced_msa_mask)

        # determine angles, if specified

        if self.predict_angles:
            omega_input = trunk_embeds if self.symmetrize_omega else x
            ret.omega_logits = self.to_prob_omega(omega_input)

        if not self.predict_coords or return_trunk:
            return ret

        # derive single and pairwise embeddings for structural refinement

        single_msa_repr_row = m[:, 0]

        single_repr = self.msa_to_single_repr_dim(single_msa_repr_row)
        pairwise_repr = self.trunk_to_pairwise_repr_dim(x)

        # prepare float32 precision for equivariance

        original_dtype = single_repr.dtype
        single_repr, pairwise_repr = map(lambda t: t.float(),
                                         (single_repr, pairwise_repr))

        # iterative refinement with equivariant transformer in high precision

        with torch_default_dtype(torch.float32):

            quaternions = torch.tensor([1., 0., 0., 0.],
                                       device=device)  # initial rotations
            quaternions = repeat(quaternions, 'd -> b n d', b=b, n=n)
            translations = torch.zeros((b, n, 3), device=device)

            # go through the layers and apply invariant point attention and feedforward

            for i in range(self.structure_module_depth):
                is_last = i == (self.structure_module_depth - 1)

                # the detach comes from
                # https://github.com/deepmind/alphafold/blob/0bab1bf84d9d887aba5cfb6d09af1e8c3ecbc408/alphafold/model/folding.py#L383
                rotations = quaternion_to_matrix(quaternions)

                if not is_last:
                    rotations = rotations.detach()

                single_repr = self.ipa_block(single_repr,
                                             mask=mask,
                                             pairwise_repr=pairwise_repr,
                                             rotations=rotations,
                                             translations=translations)

                # update quaternion and translation

                quaternion_update, translation_update = self.to_quaternion_update(
                    single_repr).chunk(2, dim=-1)
                quaternion_update = F.pad(quaternion_update, (1, 0), value=1.)

                quaternions = quaternion_multiply(quaternions,
                                                  quaternion_update)
                translations = translations + einsum(
                    'b n c, b n c r -> b n r', translation_update, rotations)

            points_local = self.to_points(single_repr)
            rotations = quaternion_to_matrix(quaternions)
            coords = einsum('b n c, b n c d -> b n d', points_local,
                            rotations) + translations

        coords = coords.type(original_dtype)

        if return_recyclables:
            coords, single_msa_repr_row, pairwise_repr = map(
                torch.detach, (coords, single_msa_repr_row, pairwise_repr))
            ret.recyclables = Recyclables(coords, single_msa_repr_row,
                                          pairwise_repr)

        if return_aux_logits:
            return coords, ret

        if return_confidence:
            return coords, self.lddt_linear(single_repr.float())

        return coords
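A minimal sketch of the quaternion-update step inside the structure module above, with random tensors standing in for the network output (helper functions assumed from pytorch3d.transforms):

import torch
import torch.nn.functional as F
from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix

b, n = 2, 16
quaternions = torch.tensor([1., 0., 0., 0.]).expand(b, n, 4)   # identity rotations (w, x, y, z)
translations = torch.zeros(b, n, 3)

# Stand-in for the per-residue update that to_quaternion_update would predict.
quaternion_update = 0.01 * torch.randn(b, n, 3)
translation_update = 0.01 * torch.randn(b, n, 3)

# Prepend w = 1 so the update is a small rotation near the identity.
quaternion_update = F.pad(quaternion_update, (1, 0), value=1.)
quaternions = quaternion_multiply(quaternions, quaternion_update)

rotations = quaternion_to_matrix(quaternions)                  # (b, n, 3, 3)
translations = translations + torch.einsum(
    'bnc,bncr->bnr', translation_update, rotations)
print(rotations.shape, translations.shape)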