def get_scaled_orthographic_projection(scale, trans, quat, transpose=False):
    """
    Generate the rotation and translation of a scaled orthographic projection
    from per-batch scale, 2D translation and quaternion rotation values.

    :param scale: A [B] tensor with the scale values for the batch
    :param trans: A [B, 2] tensor with tx and ty values for the batch
    :param quat: A [B, 4] tensor with quaternion values for the batch
    :param transpose: If True, transpose the rotation before applying the scale
    :return: A tuple (rotation, translation)
        rotation - A [B, 3, 3] tensor combining scale and rotation
        translation - A [B, 3] tensor for translation (depth is fixed to 5)
        Outputs are created on the same device as `scale`.
    """
    device = scale.device
    translation = torch.cat(
        (trans,
         torch.ones([trans.size(0), 1], dtype=torch.float, device=device) * 5),
        dim=1)

    scale_matrix = torch.zeros((scale.size(0), 3, 3), device=device)
    scale_matrix[:, 0, 0] = scale
    scale_matrix[:, 1, 1] = scale
    scale_matrix[:, 2, 2] = scale

    rotation = quaternion_to_matrix(quat)
    if transpose:
        rotation = rotation.permute(0, 2, 1)
    rotation = torch.matmul(scale_matrix, rotation)

    return rotation, translation
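
# --- Illustrative usage sketch (not part of the original module); assumes
# pytorch3d's quaternion_to_matrix is in scope as above and uses the
# (w, x, y, z) quaternion convention it expects.
import torch

example_scale = torch.rand(2)                        # [B]
example_trans = torch.rand(2, 2)                     # [B, 2] -> (tx, ty)
example_quat = torch.tensor([[1., 0., 0., 0.]] * 2)  # identity rotations
example_R, example_t = get_scaled_orthographic_projection(
    example_scale, example_trans, example_quat)
assert example_R.shape == (2, 3, 3) and example_t.shape == (2, 3)
assert (example_t[:, 2] == 5).all()                  # depth fixed to 5 by construction
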
def axis_angle_loss(axis_angle, quat):
    # Predicted rotation from axis-angle; the transpose is its inverse, so
    # output @ label below is the relative rotation prediction -> label.
    output = tr.axis_angle_to_matrix(axis_angle)
    output = torch.transpose(output, 1, 2)
    # Reorder the label quaternion from (x, y, z, w) to the (w, x, y, z)
    # convention expected by quaternion_to_matrix; keep dtype/device of `quat`.
    q2c = torch.zeros_like(quat)
    q2c[:, 0] = quat[:, 3]
    q2c[:, 1:] = quat[:, 0:3]
    label = tr.quaternion_to_matrix(q2c)
    # Geodesic angle: arccos((trace(R_rel) - 1) / 2); the clamp guards against
    # values slightly outside [-1, 1] due to floating-point error.
    cos_angle = (torch.diagonal(
        torch.matmul(output, label), dim1=-2, dim2=-1).sum(-1) - 1) / 2
    diff = torch.acos(torch.clamp(cos_angle, -1.0, 1.0))
    return diff.mean()
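
# --- Illustrative check (not part of the original code); assumes `tr` above is
# pytorch3d.transforms and that labels use the (x, y, z, w) ordering. For an
# identity axis-angle and an identity quaternion the geodesic loss is ~0.
import torch

example_aa = torch.zeros(1, 3)                       # identity rotation, axis-angle
example_q_xyzw = torch.tensor([[0., 0., 0., 1.]])    # identity quaternion, scalar last
print(axis_angle_loss(example_aa, example_q_xyzw))   # expected: tensor(0.) (approx.)
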
def forward(self, source, target, fsource, ftarget): ''' Input: point cloud (B, N, 3; unused) and feature (B, F, N) Output: rotation (B, 3, 3) and translation (B, 3) ''' x = torch.cat([fsource, ftarget], dim=1) x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 2*self.embed_dim) # (B, 2F) for i in range(self.n_mlp_layers): x = F.relu(self.bn[i](self.fc[i](x))) # Rotation (B, 4) -> (B, 3, 3) rot = self.rotation(x) rot = rot / torch.linalg.norm(rot, dim=1, keepdim=True) rot = transforms.quaternion_to_matrix(rot) # Translation (B ,3) tr = self.translation(x) return rot, tr
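
# --- Minimal sketch of the rotation-head post-processing above (illustrative,
# not from the original repo): a raw 4-vector is normalized to a unit
# quaternion and converted to a matrix, guaranteeing a valid SO(3) output.
import torch
from pytorch3d import transforms

raw = torch.randn(8, 4)                                    # (B, 4) head output
quat = raw / torch.linalg.norm(raw, dim=1, keepdim=True)   # unit quaternions
R = transforms.quaternion_to_matrix(quat)                  # (B, 3, 3)
# Sanity check: R^T R should be (numerically) the identity for every element.
print(torch.allclose(R.transpose(1, 2) @ R, torch.eye(3).expand(8, 3, 3), atol=1e-5))
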
def _make_node_transform(node: Dict[str, Any]) -> Transform3d:
    """
    Convert a transform from the json data into a PyTorch3D
    Transform3d format.
    """
    array = node.get("matrix")
    if array is not None:  # Stored in column-major order
        M = np.array(array, dtype=np.float32).reshape(4, 4, order="F")
        return Transform3d(matrix=torch.from_numpy(M))

    out = Transform3d()

    # Given some of (scale/rotation/translation), we do them in that order to
    # get points into the world space.
    # See https://github.com/KhronosGroup/glTF/issues/743 .

    array = node.get("scale", None)
    if array is not None:
        scale_vector = torch.FloatTensor(array)
        out = out.scale(scale_vector[None])

    # Rotation quaternion (x, y, z, w) where w is the scalar
    array = node.get("rotation", None)
    if array is not None:
        x, y, z, w = array
        # We negate w. This is equivalent to inverting the rotation.
        # This is needed as quaternion_to_matrix makes a matrix which
        # operates on column vectors, whereas Transform3d wants a
        # matrix which operates on row vectors.
        rotation_quaternion = torch.FloatTensor([-w, x, y, z])
        rotation_matrix = quaternion_to_matrix(rotation_quaternion)
        out = out.rotate(R=rotation_matrix)

    array = node.get("translation", None)
    if array is not None:
        translation_vector = torch.FloatTensor(array)
        out = out.translate(x=translation_vector[None])

    return out
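
# --- Illustrative call (not from the original file): builds the Transform3d
# for a glTF-style node dict with scale, (x, y, z, w) rotation and translation,
# then applies it to the origin.
import torch

example_node = {
    "scale": [2.0, 2.0, 2.0],
    "rotation": [0.0, 0.0, 0.0, 1.0],   # identity quaternion, scalar last
    "translation": [1.0, 0.0, 0.0],
}
example_t = _make_node_transform(example_node)
print(example_t.transform_points(torch.zeros(1, 3)))  # expected: [[1., 0., 0.]]
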
def build_rotation(rotation, format="matrix") -> Rotation:
    """
    Convert rotation (given in `format`) into a Rotation.
    format: matrix, ortho6d, quat, euler
    """
    # 1. CONVERT SPECIFIED FORMAT TO MATRIX FIRST
    if format == "matrix":
        matrix = rotation
    elif format == "ortho6d":
        matrix = compute_rotation_matrix_from_ortho6d(rotation)
    elif format == "euler":
        matrix = euler_angles_to_matrix(rotation, convention="XYZ")
    elif format == "quat":
        matrix = quaternion_to_matrix(rotation)
    else:
        raise TypeError(f"Unsupported rotation format: {format}")

    # 2. BUILD ROTATION
    return Rotation(
        ortho6d=rotation if format == "ortho6d" else
        compute_ortho6d_from_rotation_matrix(matrix),
        quat=rotation if format == "quat" else matrix_to_quaternion(matrix),
        matrix=rotation if format == "matrix" else matrix,
        euler=rotation if format == "euler" else matrix_to_euler_angles(
            matrix, convention="XYZ"))
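
# --- Illustrative round trip (not from the original module); relies on the
# same helpers used above and assumes Rotation exposes its fields as
# attributes. Every field of the result describes the same rotation.
import torch

example_q = torch.tensor([[1., 0., 0., 0.]])     # identity quaternion (w, x, y, z)
example_rot = build_rotation(example_q, format="quat")
print(torch.allclose(example_rot.matrix, torch.eye(3)[None]))  # expected: True
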
def forward(self, inputs: List[Dict[Hashable, th.Tensor]] ) -> List[List[Tuple[th.Tensor, th.Tensor]]]: inputs0 = inputs.copy() # NOTE(ycho): custom transform + crop-aware collation. inputs = [self.transform(x) for x in inputs] inputs = collate_cropped_img(inputs) if (Schema.CROPPED_IMAGE not in inputs or len(inputs[Schema.CROPPED_IMAGE]) <= 0): return None # if we use the crop-aware collation, (batch_idx, instance_idx) indices = inputs[Schema.INDEX] dim, quat = self.model(inputs[Schema.CROPPED_IMAGE].to( self.device)) # NOTE(ycho): len(image) appropriated for batch_size batch_size = len(inputs0) outputs = [[] for _ in range(batch_size)] for i, (ii, s, q) in enumerate(zip(indices, dim, quat)): batch_index, instance_index = ii P = inputs0[batch_index][Schema.PROJECTION].reshape(4, 4) R = quaternion_to_matrix(q[None])[0] #R2 = inputs0[batch_index][ # Schema.ORIENTATION][instance_index].reshape( # 3, 3) #q2 = matrix_to_quaternion(R2[None])[0] #s2 = inputs0[batch_index][Schema.SCALE][instance_index] # Fix BOX_2D convention. box_i, box_j, box_h, box_w = inputs[Schema.BOX_2D][i] box_2d = th.as_tensor([box_i, box_j, box_i + box_h, box_j + box_w]) box_2d = 2.0 * (box_2d - 0.5) # Solve translation translation, _ = self.solve_translation({ # inputs from dataset Schema.PROJECTION: P, Schema.BOX_2D: box_2d, # inputs from network Schema.ORIENTATION: R, Schema.QUATERNION: q, Schema.SCALE: s # inputs from dataset (ground-truth) # Schema.ORIENTATION: R2, # Schema.QUATERNION: q2, # Schema.SCALE: s2 }) translation = th.as_tensor(translation, device=R.device) if translation[-1] > 0: translation *= -1.0 #print('tr-solve') #print(translation) #print('tr-gt') #print(inputs0[batch_index][Schema.TRANSLATION][instance_index]) # Convert to box-points box_out = self.box_points({ Schema.ORIENTATION: R.to(self.device), Schema.TRANSLATION: translation.to(self.device), Schema.SCALE: s.to(self.device), Schema.PROJECTION: P.to(self.device), Schema.INSTANCE_NUM: 1 }) entry = ( box_out[Schema.KEYPOINT_2D][0, ..., :2].detach().cpu().numpy(), box_out[Schema.KEYPOINT_3D][0].detach().cpu().numpy() ) outputs[batch_index].append(entry) return outputs
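
# --- Small numeric sketch of the BOX_2D convention fix above (illustrative
# only): an (i, j, h, w) box in [0, 1] coordinates becomes (i0, j0, i1, j1)
# corners remapped to the [-1, 1] range expected by the translation solver.
import torch as th

box_i, box_j, box_h, box_w = 0.25, 0.25, 0.5, 0.5
example_box = th.as_tensor([box_i, box_j, box_i + box_h, box_j + box_w])
example_box = 2.0 * (example_box - 0.5)
print(example_box)  # expected: tensor([-0.5000, -0.5000,  0.5000,  0.5000])
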
def __call__(
        self,
        inputs: Dict[Hashable, th.Tensor]) -> Dict[Hashable, th.Tensor]:
    proj_matrix = inputs[Schema.PROJECTION]
    # NOTE(ycho): BOX_2D = (i0, j0, i1, j1) in normalized coords (-1,1).
    box_2d = inputs[Schema.BOX_2D]

    # NOTE(ycho): outputs from network
    dimension = inputs[Schema.SCALE]
    if Schema.ORIENTATION in inputs:
        R = inputs[Schema.ORIENTATION].detach().cpu().numpy()
    elif Schema.QUATERNION in inputs:
        quaternion = inputs[Schema.QUATERNION]
        R = (quaternion_to_matrix(
            th.as_tensor(quaternion)).detach().cpu().numpy())
    else:
        raise KeyError('Orientation information Not Found!')

    vertices = (self.points.cpu() * dimension.cpu()).detach().numpy()

    if True:
        # Reduce the number of permutations through geometric reasoning.
        fovs = 2.0 * np.arctan(1.0 / proj_matrix[[0, 1], [0, 1]])
        with warnings.catch_warnings():
            warnings.filterwarnings(action='ignore', category=LinAlgWarning)
            warnings.filterwarnings(action='ignore', category=OptimizeWarning)
            warnings.filterwarnings(action='ignore',
                                    category=np.VisibleDeprecationWarning)
            warnings.filterwarnings(action='ignore', category=RuntimeWarning)
            perms = compute_feasible_permutations(vertices @ R.T, fovs,
                                                  self.debug_chull)
        perms = np.asarray(list(perms), dtype=np.int32)
        constraints = vertices[perms, :]
    else:
        constraints = list(itertools.permutations(vertices, 4))

    # Initialize current best candidates.
    best_loc = None
    best_error = np.inf
    best_X = None

    # Loop through each possible constraint, hold on to the best guess
    K = proj_matrix.detach().cpu().numpy()

    # Create design matrices Ax=b for SVD.
    # K_ax is the axes of K repeated for each corresponding spatial axis.
    K_ax = K[(0, 1, 0, 1), :3]

    # TODO(ycho): use integer permutations directly and index into the array,
    # instead of creating (large) redundant copies
    for X in constraints:
        A = np.einsum('n,a->na', box_2d, K[2, :3]) - K_ax
        b = -np.einsum('na,ab,nb->n', A, R, X)

        # Solve here with least squares since overparameterized.
        # NOTE(ycho): `error` here indicates algebraic error;
        # it's generally preferable to use geometric error.
        loc, error, rank, s = np.linalg.lstsq(A, b, rcond=None)

        # Evaluate solution ...
        if self.recompute_error:
            # NOTE(ycho): evaluate error based on match with box.
            # FIXME(ycho): This probably results in much more expensive
            # evaluation.
            args = {
                Schema.ORIENTATION: th.as_tensor(R).detach().cpu(),
                Schema.TRANSLATION: th.as_tensor(loc).detach().cpu(),
                Schema.SCALE: th.as_tensor(dimension).detach().cpu(),
                Schema.PROJECTION: th.as_tensor(K).detach().cpu(),
                Schema.INSTANCE_NUM: 1
            }
            out_points = self.box_points(args)[Schema.KEYPOINT_2D][..., :2]
            out_points = th.flip(out_points, dims=(-1, ))  # XY->IJ
            out_points = 2.0 * (out_points - 0.5)  # (0,1) -> (-1, +1)
            pmin = out_points.min(dim=-2).values.reshape(-1)
            pmax = out_points.max(dim=-2).values.reshape(-1)
            out_box_2d = th.cat([pmin, pmax])
            error2 = th.norm(box_2d - out_box_2d.to(box_2d.device))
            error = error2

        # Update estimate with better alternative.
        if (error < best_error):
            best_loc = loc
            best_error = error
            best_X = X
    return best_loc, best_X
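
# --- Toy version of the least-squares step above (illustrative only; K, R, X
# and the box are made up). Each of the four box-edge constraints contributes
# one row of A; the candidate translation is the lstsq solution of A t = b.
import numpy as np

K = np.array([[600., 0., 320., 0.],
              [0., 600., 240., 0.],
              [0., 0., 1., 0.],
              [0., 0., 0., 1.]])
R = np.eye(3)
X = np.array([[0.1, 0.1, 0.1],     # one object-frame vertex per constraint
              [-0.1, 0.1, 0.1],
              [0.1, -0.1, 0.1],
              [-0.1, -0.1, 0.1]])
box_2d = np.array([-0.5, -0.5, 0.5, 0.5])   # (i0, j0, i1, j1) in [-1, 1]
K_ax = K[(0, 1, 0, 1), :3]
A = np.einsum('n,a->na', box_2d, K[2, :3]) - K_ax
b = -np.einsum('na,ab,nb->n', A, R, X)
loc, residual, rank, sv = np.linalg.lstsq(A, b, rcond=None)
print(loc)   # candidate translation for this vertex-to-edge assignment
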
def main():
    # data
    transform = Compose([
        CropObject(CropObject.Settings()),
        Normalize(Normalize.Settings(keys=(Schema.CROPPED_IMAGE, )))
    ])
    _, test_loader = get_loaders(DatasetSettings(),
                                 th.device('cpu'),
                                 1,
                                 transform=transform,
                                 collate_fn=collate_cropped_img)

    # model
    device = th.device('cuda')
    model = load_model()
    model = model.to(device)
    model.eval()

    # translation solver?
    solve_translation = SolveTranslation()

    box_points = BoxPoints2D(th.device('cpu'), Schema.KEYPOINT_2D)
    draw_bbox = DrawBoundingBoxFromKeypoints(
        DrawBoundingBoxFromKeypoints.Settings())

    # eval
    for data in test_loader:
        # Skip occasional batches without any images.
        if Schema.CROPPED_IMAGE not in data:
            continue

        with th.no_grad():
            # run inference
            crop_img = data[Schema.CROPPED_IMAGE].view(-1, 3, 224, 224)
            dim, quat = model(crop_img.to(device))
            dim2, quat2 = data[Schema.SCALE], data[Schema.QUATERNION]
            logging.debug('D {} {}'.format(dim, dim2))
            logging.debug('Q {} {}'.format(quat, quat2))
            # trans = data[Schema.TRANSLATION]
            if False:
                dim = dim2
                quat = quat2
            R = quaternion_to_matrix(quat)

        input_image = data[Schema.IMAGE].detach().cpu()
        proj_matrix = (data[Schema.PROJECTION].detach().cpu().reshape(
            -1, 4, 4))

        # Solve translations.
        translations = []
        for i in range(len(proj_matrix)):
            box_i, box_j, box_h, box_w = data[Schema.BOX_2D][i]
            box_2d = th.as_tensor(
                [box_i, box_j, box_i + box_h, box_j + box_w])
            box_2d = 2.0 * (box_2d - 0.5)
            args = {
                # inputs from dataset
                Schema.PROJECTION: proj_matrix[i],
                Schema.BOX_2D: box_2d,
                # inputs from network
                Schema.ORIENTATION: R[i],
                Schema.QUATERNION: quat[i],
                Schema.SCALE: dim[i]
            }
            # Solve translation
            translation, _ = solve_translation(args)
            translations.append(translation)
        translations = th.as_tensor(translations, dtype=th.float32)

        if True:
            print('num instances = {}'.format(len(translations)))
            pred_data = {
                Schema.IMAGE: data[Schema.IMAGE][0],
                Schema.ORIENTATION: R.cpu(),
                Schema.TRANSLATION: translations,
                Schema.SCALE: dim.cpu(),
                Schema.PROJECTION: proj_matrix[0],
                Schema.INSTANCE_NUM: len(proj_matrix),
            }
            pred_data = box_points(pred_data)
            pred_data = draw_bbox(pred_data)
            image_with_box = pred_data['img_w_bbox']
        else:
            dimensions = dim.detach().cpu()
            quaternion = quat.detach().cpu()
            translations = translations.detach().cpu()

            #print(input_image.shape)
            #print(data[Schema.BOX_2D].shape)
            #print(proj_matrix.shape)
            #print(translations.shape)
            #print(dimensions.shape)
            #print(quaternion.shape)

            # draw box
            image_with_box = plot_regressed_3d_bbox(
                input_image,
                # keypoints_2d,
                # data[Schema.BOX_2D],
                data[Schema.KEYPOINT_2D],
                proj_matrix,
                dimensions,
                quaternion,
                translations)

        plt.clf()
        plt.imshow(image_with_box.permute(1, 2, 0))
        plt.pause(0.1)
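
# --- Minor sketch of the display step above (illustrative only): rendered
# images are CHW tensors, while matplotlib's imshow expects HWC, hence the
# permute before plotting.
import torch as th
import matplotlib.pyplot as plt

example_img_chw = th.rand(3, 224, 224)
plt.imshow(example_img_chw.permute(1, 2, 0))   # CHW -> HWC
plt.pause(0.1)
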
def run_optimization( self, silhouettes: torch.tensor, R: torch.tensor, T: torch.tensor, writer=None, camera_settings=None, step: int = 0, ): """ Function: Runs a batched optimization procedure that aims to minimize 3 reconstruction losses: -Silhouette IoU Loss: between input silhouettes and re-projected mesh -Mesh Edge consistency -Mesh Normal smoothing Mini Batching: If the number silhouettes is greater than the allowed batch size then a random set of images/poses is sampled for supervision at each step Returns: -Reconstruction losses: 3 reconstruction losses measured during optimization -Timing: -Iterations / second -Total time elapsed in seconds """ if len(R.shape) == 4: R = R.squeeze(1) T = T.squeeze(1) tf_smaller = transforms.Compose([ transforms.ToPILImage(), transforms.Resize(self.params.img_size), transforms.ToTensor(), ]) images_gt = torch.stack([ tf_smaller(s.cpu()).to(self.device) for s in silhouettes ]).squeeze(1) if images_gt.max() > 1.0: images_gt = images_gt / 255.0 loop = tqdm_notebook(range(self.params.mesh_steps)) start_time = time.time() for i in loop: batch_indices = (random.choices(list(range(images_gt.shape[0])), k=self.params.mesh_batch_size) if images_gt.shape[0] > self.params.mesh_batch_size else list(range(images_gt.shape[0]))) batch_silhouettes = images_gt[batch_indices] batch_R, batch_T = R[batch_indices], T[batch_indices] # apply right transform on the Twv to adjust the coordinate system shift from EVIMO to PyTorch3D if self.params.is_real_data: init_R = quaternion_to_matrix(self.init_camera_R) batch_R = _broadcast_bmm(batch_R, init_R) batch_T = ( _broadcast_bmm(batch_T[:, None, :], init_R) + self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :] focal_length = (torch.tensor([ camera_settings[0, 0], camera_settings[1, 1] ])[None]).expand(batch_R.shape[0], 2) principle_point = (torch.tensor([ camera_settings[0, 2], camera_settings[1, 2] ])[None]).expand(batch_R.shape[0], 2) # FIXME: in this PyTorch3D version, the image_size in RasterizationSettings is (W, H), while in PerspectiveCameras is (H, W) # If the future pytorch3d change the format, please change the settings here # We hope PyTorch3D will solve this issue in the future batch_cameras = PerspectiveCameras( device=self.device, R=batch_R, T=batch_T, focal_length=focal_length, principal_point=principle_point, image_size=((self.params.img_size[1], self.params.img_size[0]), )) else: batch_cameras = PerspectiveCameras(device=self.device, R=batch_R, T=batch_T) mesh, laplacian_loss, flatten_loss = self.forward( self.params.mesh_batch_size) images_pred = self.renderer(mesh, device=self.device, cameras=batch_cameras)[..., -1] iou_loss = IOULoss().forward(batch_silhouettes, images_pred) loss = (iou_loss * self.params.lambda_iou + laplacian_loss * self.params.lambda_laplacian + flatten_loss * self.params.lambda_flatten) loop.set_description("Optimizing (loss %.4f)" % loss.data) self.losses["iou"].append(iou_loss * self.params.lambda_iou) self.losses["laplacian"].append(laplacian_loss * self.params.lambda_laplacian) self.losses["flatten"].append(flatten_loss * self.params.lambda_flatten) self.optimizer.zero_grad() loss.backward() self.optimizer.step() if i % (self.params.mesh_show_step / 2) == 0 and self.params.mesh_log: logging.info( f'Iteration: {i} IOU Loss: {iou_loss.item()} Flatten Loss: {flatten_loss.item()} Laplacian Loss: {laplacian_loss.item()}' ) if i % self.params.mesh_show_step == 0 and self.params.im_show: # Write images image = images_pred.detach().cpu().numpy()[0] if writer: writer.append_data((255 
* image).astype(np.uint8)) plt.imshow(images_pred.detach().cpu().numpy()[0]) plt.show() plt.imshow(batch_silhouettes.detach().cpu().numpy()[0]) plt.show() plot_pointcloud(mesh[0], 'Mesh deformed') logging.info( f'Pose of init camera: {self.init_camera_R.detach().cpu().numpy()}, {self.init_camera_t.detach().cpu().numpy()}' ) # Set the final optimized mesh as an internal variable self.final_mesh = mesh[0].clone() results = dict( silhouette_loss=self.losses["iou"] [-1].detach().cpu().numpy().tolist(), laplacian_loss=self.losses["laplacian"] [-1].detach().cpu().numpy().tolist(), flatten_loss=self.losses["flatten"] [-1].detach().cpu().numpy().tolist(), iterations_per_second=self.params.mesh_steps / (time.time() - start_time), total_time_s=time.time() - start_time, ) if self.is_real_data: self.init_pose_R = self.init_camera_R.detach().cpu().numpy() self.init_pose_t = self.init_camera_t.detach().cpu().numpy() torch.cuda.empty_cache() return results
def render_final_mesh(self, poses, mode: str, out_size: list, camera_settings=None) -> dict: """Renders the final mesh obtained through optimization Supports two modes: -predict: renders both silhouettes and flat shaded images -train: only renders silhouettes Returns: -dict of renders {'silhouettes': tensor, 'images': tensor} """ R, T = poses if len(R.shape) == 4: R = R.squeeze(1) T = T.squeeze(1) sil_renderer = silhouette_renderer(out_size, self.device) image_renderer = flat_renderer(out_size, self.device) # Create a silhouette projection of the mesh across all views all_silhouettes = [] all_images = [] for i in range(0, R.shape[0]): batch_R, batch_T = R[[i]], T[[i]] if self.params.is_real_data: init_R = quaternion_to_matrix(self.init_camera_R) batch_R = _broadcast_bmm(batch_R, init_R) batch_T = ( _broadcast_bmm(batch_T[:, None, :], init_R) + self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :] focal_length = torch.tensor( [camera_settings[0, 0], camera_settings[1, 1]])[None] principle_point = torch.tensor( [camera_settings[0, 2], camera_settings[1, 2]])[None] t_cameras = PerspectiveCameras( device=self.device, R=batch_R, T=batch_T, focal_length=focal_length, principal_point=principle_point, image_size=((self.params.img_size[1], self.params.img_size[0]), )) else: t_cameras = PerspectiveCameras(device=self.device, R=batch_R, T=batch_T) all_silhouettes.append( sil_renderer(self._final_mesh, device=self.device, cameras=t_cameras).detach().cpu()[..., -1]) if mode == "predict": all_images.append( torch.clamp( image_renderer(self._final_mesh, device=self.device, cameras=t_cameras), 0, 1, ).detach().cpu()[..., :3]) torch.cuda.empty_cache() renders = dict( silhouettes=torch.cat(all_silhouettes).unsqueeze(-1).permute( 0, 3, 1, 2), images=torch.cat(all_images) if all_images else [], ) return renders
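
# --- Minimal camera-construction sketch mirroring the real-data branch above
# (illustrative only): focal length and principal point are read from a 3x3
# intrinsics matrix and image_size is (H, W). Depending on the PyTorch3D
# version, pixel-space intrinsics may additionally require in_ndc=False.
import torch
from pytorch3d.renderer import PerspectiveCameras

example_K = torch.tensor([[600., 0., 320.],
                          [0., 600., 240.],
                          [0., 0., 1.]])
example_cameras = PerspectiveCameras(
    focal_length=example_K[[0, 1], [0, 1]][None],     # (1, 2) -> (fx, fy)
    principal_point=example_K[[0, 1], [2, 2]][None],  # (1, 2) -> (cx, cy)
    image_size=((480, 640), ),                        # (H, W)
    R=torch.eye(3)[None],
    T=torch.zeros(1, 3))
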
points = points.bmm(r) points = points + t[:, None, :] return points if __name__ == "__main__": rand_pt = torch.randn(4, 1000, 3) rand_qt = random_qt(4, 0.5, 3) # qt -> rt q = rand_qt[:, :4] t = rand_qt[:, 4:, None] R = pt3d_T.quaternion_to_matrix(q) Rinv = pt3d_T.quaternion_to_matrix(pt3d_T.quaternion_invert(q)) Rti = torch.cat((Rinv, t), dim=2) Rt = torch.cat((R, t), dim=2) rot_qt = transform_points_qt(rand_pt, rand_qt) rot_Rt = transform_points_Rt(rand_pt, Rt) rot_Rti = transform_points_Rt(rand_pt, Rti) qt_Rt = (rot_qt - rot_Rt).norm(dim=2, p=2).mean() qt_Rti = (rot_qt - rot_Rti).norm(dim=2, p=2).mean() Rt_Rti = (rot_Rti - rot_Rt).norm(dim=2, p=2).mean() print(f"|| points ||: {rand_pt.norm(p=2,dim=2).mean()}") print(f"Diff Rt and qt: {qt_Rt:.4e}")
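
# --- Illustrative identity behind the comparison above (not part of the
# original script): inverting a quaternion corresponds to transposing its
# rotation matrix, so Rinv above equals R.transpose(1, 2).
import torch
from pytorch3d import transforms as pt3d_T

example_q = pt3d_T.random_quaternions(4)
example_R = pt3d_T.quaternion_to_matrix(example_q)
example_Rinv = pt3d_T.quaternion_to_matrix(pt3d_T.quaternion_invert(example_q))
print(torch.allclose(example_Rinv, example_R.transpose(1, 2), atol=1e-6))  # True
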
def forward(self, seq, msa=None, mask=None, msa_mask=None, extra_msa=None, extra_msa_mask=None, seq_index=None, seq_embed=None, msa_embed=None, templates_feats=None, templates_mask=None, templates_angles=None, embedds=None, recyclables=None, return_trunk=False, return_confidence=False, return_recyclables=False, return_aux_logits=False): assert not ( self.disable_token_embed and not exists(seq_embed) ), 'sequence embedding must be supplied if one has disabled token embedding' assert not ( self.disable_token_embed and not exists(msa_embed) ), 'msa embedding must be supplied if one has disabled token embedding' # if MSA is not passed in, just use the sequence itself if not exists(msa): msa = rearrange(seq, 'b n -> b () n') msa_mask = rearrange(mask, 'b n -> b () n') # assert on sequence length assert msa.shape[-1] == seq.shape[ -1], 'sequence length of MSA and primary sequence must be the same' # variables b, n, device = *seq.shape[:2], seq.device n_range = torch.arange(n, device=device) # unpack (AA_code, atom_pos) if isinstance(seq, (list, tuple)): seq, seq_pos = seq # embed main sequence x = self.token_emb(seq) if exists(seq_embed): x += seq_embed # mlm for MSAs if self.training and exists(msa): original_msa = msa msa_mask = default(msa_mask, lambda: torch.ones_like(msa).bool()) noised_msa, replaced_msa_mask = self.mlm.noise(msa, msa_mask) msa = noised_msa # embed multiple sequence alignment (msa) if exists(msa): m = self.token_emb(msa) if exists(msa_embed): m = m + msa_embed # add single representation to msa representation m = m + rearrange(x, 'b n d -> b () n d') # get msa_mask to all ones if none was passed msa_mask = default(msa_mask, lambda: torch.ones_like(msa).bool()) elif exists(embedds): m = self.embedd_project(embedds) # get msa_mask to all ones if none was passed msa_mask = default( msa_mask, lambda: torch.ones_like(embedds[..., -1]).bool()) else: raise Error('either MSA or embeds must be given') # derive pairwise representation x_left, x_right = self.to_pairwise_repr(x).chunk(2, dim=-1) x = rearrange(x_left, 'b i d -> b i () d') + rearrange( x_right, 'b j d-> b () j d') # create pair-wise residue embeds x_mask = rearrange(mask, 'b i -> b i ()') * rearrange( mask, 'b j -> b () j') if exists(mask) else None # add relative positional embedding seq_index = default(seq_index, lambda: torch.arange(n, device=device)) seq_rel_dist = rearrange(seq_index, 'i -> () i ()') - rearrange( seq_index, 'j -> () () j') seq_rel_dist = seq_rel_dist.clamp( -self.max_rel_dist, self.max_rel_dist) + self.max_rel_dist rel_pos_emb = self.pos_emb(seq_rel_dist) x = x + rel_pos_emb # add recyclables, if present if exists(recyclables): m[:, 0] = m[:, 0] + self.recycling_msa_norm( recyclables.single_msa_repr_row) x = x + self.recycling_pairwise_norm(recyclables.pairwise_repr) distances = torch.cdist(recyclables.coords, recyclables.coords, p=2) boundaries = torch.linspace(2, 20, steps=self.recycling_distance_buckets, device=device) discretized_distances = torch.bucketize(distances, boundaries[:-1]) distance_embed = self.recycling_distance_embed( discretized_distances) x = x + distance_embed # embed templates, if present if exists(templates_feats): _, num_templates, *_ = templates_feats.shape # embed template t = self.to_template_embed(templates_feats) t_mask_crossed = rearrange(templates_mask, 'b t i -> b t i ()') * rearrange( templates_mask, 'b t j -> b t () j') t = rearrange(t, 'b t ... -> (b t) ...') t_mask_crossed = rearrange(t_mask_crossed, 'b t ... 
-> (b t) ...') for _ in range(self.templates_embed_layers): t = self.template_pairwise_embedder(t, mask=t_mask_crossed) t = rearrange(t, '(b t) ... -> b t ...', t=num_templates) t_mask_crossed = rearrange(t_mask_crossed, '(b t) ... -> b t ...', t=num_templates) # template pos emb x_point = rearrange(x, 'b i j d -> (b i j) () d') t_point = rearrange(t, 'b t i j d -> (b i j) t d') x_mask_point = rearrange(x_mask, 'b i j -> (b i j) ()') t_mask_point = rearrange(t_mask_crossed, 'b t i j -> (b i j) t') template_pooled = self.template_pointwise_attn( x_point, context=t_point, mask=x_mask_point, context_mask=t_mask_point) template_pooled_mask = rearrange( t_mask_point.sum(dim=-1) > 0, 'b -> b () ()') template_pooled = template_pooled * template_pooled_mask template_pooled = rearrange(template_pooled, '(b i j) () d -> b i j d', i=n, j=n) x = x + template_pooled # add template angle features to MSAs by passing through MLP and then concat if exists(templates_angles): t_angle_feats = self.template_angle_mlp(templates_angles) m = torch.cat((m, t_angle_feats), dim=1) msa_mask = torch.cat((msa_mask, templates_mask), dim=1) # embed extra msa, if present if exists(extra_msa): extra_m = self.token_emb(msa) extra_msa_mask = default(extra_msa_mask, torch.ones_like(extra_m).bool()) x, extra_m = self.extra_msa_evoformer(x, extra_m, mask=x_mask, msa_mask=extra_msa_mask) # trunk x, m = self.net(x, m, mask=x_mask, msa_mask=msa_mask) # ready output container ret = ReturnValues() # calculate theta and phi before symmetrization if self.predict_angles: ret.theta_logits = self.to_prob_theta(x) ret.phi_logits = self.to_prob_phi(x) # embeds to distogram trunk_embeds = (x + rearrange(x, 'b i j d -> b j i d')) * 0.5 # symmetrize distance_pred = self.to_distogram_logits(trunk_embeds) ret.distance = distance_pred # calculate mlm loss, if training msa_mlm_loss = None if self.training and exists(msa): num_msa = original_msa.shape[1] msa_mlm_loss = self.mlm(m[:, :num_msa], original_msa, replaced_msa_mask) # determine angles, if specified if self.predict_angles: omega_input = trunk_embeds if self.symmetrize_omega else x ret.omega_logits = self.to_prob_omega(omega_input) if not self.predict_coords or return_trunk: return ret # derive single and pairwise embeddings for structural refinement single_msa_repr_row = m[:, 0] single_repr = self.msa_to_single_repr_dim(single_msa_repr_row) pairwise_repr = self.trunk_to_pairwise_repr_dim(x) # prepare float32 precision for equivariance original_dtype = single_repr.dtype single_repr, pairwise_repr = map(lambda t: t.float(), (single_repr, pairwise_repr)) # iterative refinement with equivariant transformer in high precision with torch_default_dtype(torch.float32): quaternions = torch.tensor([1., 0., 0., 0.], device=device) # initial rotations quaternions = repeat(quaternions, 'd -> b n d', b=b, n=n) translations = torch.zeros((b, n, 3), device=device) # go through the layers and apply invariant point attention and feedforward for i in range(self.structure_module_depth): is_last = i == (self.structure_module_depth - 1) # the detach comes from # https://github.com/deepmind/alphafold/blob/0bab1bf84d9d887aba5cfb6d09af1e8c3ecbc408/alphafold/model/folding.py#L383 rotations = quaternion_to_matrix(quaternions) if not is_last: rotations = rotations.detach() single_repr = self.ipa_block(single_repr, mask=mask, pairwise_repr=pairwise_repr, rotations=rotations, translations=translations) # update quaternion and translation quaternion_update, translation_update = self.to_quaternion_update( 
                single_repr).chunk(2, dim=-1)
            quaternion_update = F.pad(quaternion_update, (1, 0), value=1.)
            quaternions = quaternion_multiply(quaternions, quaternion_update)
            translations = translations + einsum(
                'b n c, b n c r -> b n r', translation_update, rotations)

        points_local = self.to_points(single_repr)
        rotations = quaternion_to_matrix(quaternions)
        coords = einsum('b n c, b n c d -> b n d', points_local,
                        rotations) + translations

    # Cast back to the original dtype; .type() is not in-place, so the result
    # must be reassigned.
    coords = coords.type(original_dtype)

    if return_recyclables:
        coords, single_msa_repr_row, pairwise_repr = map(
            torch.detach, (coords, single_msa_repr_row, pairwise_repr))
        ret.recyclables = Recyclables(coords, single_msa_repr_row,
                                      pairwise_repr)

    if return_aux_logits:
        return coords, ret

    if return_confidence:
        return coords, self.lddt_linear(single_repr.float())

    return coords
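
# --- Minimal sketch of the quaternion update used in the structure module
# above (illustrative only; the real code uses its own quaternion helpers): a
# 3-vector update is padded with a leading 1 to form an unnormalized
# quaternion and composed with the current frames, as in
# F.pad(update, (1, 0), value=1.) followed by quaternion_multiply.
import torch
import torch.nn.functional as F
from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix

b, n = 2, 5
frames = torch.tensor([1., 0., 0., 0.]).expand(b, n, 4)   # identity frames
update = 0.01 * torch.randn(b, n, 3)                      # small per-residue update
frames = quaternion_multiply(frames, F.pad(update, (1, 0), value=1.))
rotations = quaternion_to_matrix(frames)                  # (b, n, 3, 3)
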