def voxel_transform(item, grid_config, rot_mat=None, center_fn=vox.get_center, random_seed=None, structure_keys=('atoms',)):
    """Transform for converting dataframes to voxelized grids compatible with 3D CNN, to be applied when defining a :mod:`Dataset <atom3d.datasets.datasets>`.

    Operates on Dataset items, assumes that the item contains all keys specified in the ``structure_keys`` argument.
    The item is modified in place: each ``item[key]`` is replaced by its voxel grid, and the same item is returned.

    :param item: Dataset item to transform
    :type item: dict
    :param grid_config: Config parameters for grid. Should contain the following keys:
        `element_mapping`, dictionary mapping from element to 1-hot index;
        `radius`, radius of grid to generate in Angstroms (half of side length);
        `resolution`, voxel size in Angstroms;
        `num_directions`, number of directions for data augmentation (required if ``rot_mat``=None);
        `num_rolls`, number of rolls, or rotations, for data augmentation (required if ``rot_mat``=None)
    :type grid_config: :class:`dotdict <atom3d.util.voxelize.dotdict>`
    :param rot_mat: Rotation matrix (3x3) to apply to structure coordinates. If None (default), apply randomly sampled rotation according to parameters specified by ``grid_config.num_directions`` and ``grid_config.num_rolls``. The rotation is sampled once and reused for every key in ``structure_keys``.
    :type rot_mat: np.array
    :param center_fn: Arbitrary function for calculating the center of the voxelized grid (x,y,z coordinates) from a structure dataframe, defaults to vox.get_center
    :type center_fn: f(df -> array), optional
    :param random_seed: random seed for grid rotation, defaults to None
    :type random_seed: int, optional
    :param structure_keys: Keys of ``item`` whose dataframes are voxelized, defaults to ``('atoms',)``
    :type structure_keys: iterable of str, optional

    :return: Transformed Dataset item
    :rtype: dict
    """
    for key in structure_keys:
        df = item[key]
        center = center_fn(df)
        if rot_mat is None:
            # Sampled lazily on the first key only; later iterations see a
            # non-None rot_mat and therefore share this same rotation.
            rot_mat = vox.gen_rot_matrix(grid_config, random_seed=random_seed)
        item[key] = vox.get_grid(
            df, center, config=grid_config, rot_mat=rot_mat)
    return item
def _feature(struct, center):
    """Voxelize ``struct`` around ``center`` and return a channel-first grid.

    NOTE(review): ``self`` is not a parameter here; presumably this function
    is defined inside a method and picks ``self`` up from the enclosing
    scope -- confirm against the surrounding code.

    :param struct: Structure passed through to ``get_grid``
    :param center: Grid center passed through to ``get_grid``
    :return: Grid with its last axis moved to the front
    """
    cfg = self.grid_config
    # Random rotation drawn from the grid configuration.
    rotation = gen_rot_matrix(cfg, random_seed=self.random_seed)
    # Build the rotated voxel grid for this structure.
    voxels = get_grid(struct, center, config=cfg, rot_mat=rotation)
    # The atom channel comes out as the last axis; PyTorch expects it first.
    return np.moveaxis(voxels, -1, 0)
def _voxelize(self, atoms):
    """Voxelize ``atoms`` on a grid centered on the structure.

    :param atoms: Structure with ``x``, ``y`` and ``z`` columns
    :return: Grid with its last axis moved to the front
    """
    # The grid is centered on the center of all atom coordinates.
    coords = atoms[['x', 'y', 'z']].astype(np.float32)
    origin = get_center(coords)

    cfg = self.grid_config
    # Random rotation drawn from the grid configuration.
    rotation = gen_rot_matrix(cfg, random_seed=self.random_seed)
    # Build the rotated voxel grid for the whole structure.
    voxels = get_grid(atoms, origin, config=cfg, rot_mat=rotation)
    # The atom channel comes out as the last axis; PyTorch expects it first.
    return np.moveaxis(voxels, -1, 0)
def _voxelize(self, atoms, mut_chain, mut_res, is_mutated):
    """Voxelize ``atoms`` and optionally append an original/mutated flag channel.

    :param atoms: Structure to voxelize
    :param mut_chain: Chain of the mutated residue, forwarded to ``self._get_voxel_center``
    :param mut_res: Mutated residue, forwarded to ``self._get_voxel_center``
    :param is_mutated: Fill value of the extra channel: original (0) or mutated (1)
    :return: Grid with its last axis moved to the front
    """
    # Center is either the CA of the mutated residue or the structure center.
    origin = self._get_voxel_center(atoms, mut_chain, mut_res)

    cfg = self.grid_config
    # Random rotation drawn from the grid configuration.
    rotation = gen_rot_matrix(cfg, random_seed=self.random_seed)
    # Build the rotated voxel grid.
    voxels = get_grid(atoms, origin, config=cfg, rot_mat=rotation)

    if self.add_flag:
        # One extra channel, constant over the grid, holding the flag value.
        flag_shape = voxels.shape[:-1] + (1, )
        flag_channel = np.full(flag_shape, is_mutated)
        voxels = np.concatenate([voxels, flag_channel], axis=3)

    # The atom channel comes out as the last axis; PyTorch expects it first.
    return np.moveaxis(voxels, -1, 0)
def _voxelize(self, atoms, is_active):
    """Voxelize ``atoms`` around the ligand and optionally append an activity flag channel.

    :param atoms: Structure with ``chain``, ``x``, ``y`` and ``z`` columns
    :param is_active: Fill value of the extra channel: inactive (0) or active (1)
    :return: Grid with its last axis moved to the front
    """
    # The grid is centered on the ligand, i.e. the rows whose chain is 'L'.
    ligand_rows = atoms[atoms.chain == 'L']
    ligand_coords = ligand_rows[['x', 'y', 'z']].astype(np.float32)
    origin = get_center(ligand_coords)

    cfg = self.grid_config
    # Random rotation drawn from the grid configuration.
    rotation = gen_rot_matrix(cfg, random_seed=self.random_seed)
    # Build the rotated voxel grid for protein and ligand together.
    voxels = get_grid(atoms, origin, config=cfg, rot_mat=rotation)

    if self.add_flag:
        # One extra channel, constant over the grid, holding the flag value.
        flag_shape = voxels.shape[:-1] + (1, )
        flag_channel = np.full(flag_shape, is_active)
        voxels = np.concatenate([voxels, flag_channel], axis=3)

    # The atom channel comes out as the last axis; PyTorch expects it first.
    return np.moveaxis(voxels, -1, 0)