Esempio n. 1
    def load(path: str, as_numpy: bool = False) -> Union[np.ndarray, \
            List[np.ndarray], TFDataset, TorchDataset]:
        """Load the dataset from a TFRecords file.

        Assumes that each record's np.ndarrays are saved under default
        names `arr_0`, ..., `arr_n` and its associated shapes as
        `shape_0`, ..., `shape_n`.

        Parameters
        ----------
        path: str
            TFRecords file which contains the dataset.

        as_numpy: bool, default=False
            If True, loads the dataset as a list of np.ndarray's (in
            its original form). If False, loads the dataset in
            P.backend()'s specified format.

        Returns
        -------
        data: np.ndarray, list of np.ndarray, TFDataset or TorchDataset
            The dataset (either in its original shape/format or in
            P.backend()'s specified format). A single array is returned
            bare rather than wrapped in a one-element list.
        """
        if not as_numpy:
            return TorchTFRecordsDataset(path) if P.backend() == "torch" \
                else TFRecordsDataset(path)

        batch_first = P.data_format() == "batch_first"
        axis = 0 if batch_first else -1

        # Collect every example per array name and concatenate once at the
        # end, instead of re-allocating the full array for each record.
        chunks = {}
        for serialized in tf.python_io.tf_record_iterator(path):
            example = tf.train.Example()
            # Without parsing, the Example proto stays empty
            example.ParseFromString(serialized)
            feature_map = example.features.feature

            # Save array shapes to ensure respective ndarrays get reshaped properly
            shapes = {name: list(feature_map[name].int64_list.value)
                      for name in feature_map if name.startswith("shape")}

            for name, tf_feature in feature_map.items():
                if not name.startswith("arr"):
                    continue
                # NOTE(review): the TFRecords saver casts float arrays to
                # 32-bit before calling tobytes(), so the buffer is decoded
                # as float32 here. Integer arrays cannot be told apart from
                # the bytes alone - confirm only float data is stored.
                parsed_data = np.frombuffer(tf_feature.bytes_list.value[0],
                                            np.float32)
                newshape = shapes["shape_{}".format(name.split("_")[-1])]
                # Re-insert the examples axis that was dropped when saving
                newshape = [1] + newshape if batch_first else newshape + [1]
                chunks.setdefault(name, []).append(
                    np.reshape(parsed_data, newshape))

        # Protobuf map iteration order is unspecified; restore arr_0..arr_n
        ordered = sorted(chunks, key=lambda key: int(key.split("_")[-1]))
        data = [np.concatenate(chunks[name], axis=axis) for name in ordered]
        return data[0] if len(data) == 1 else data
Esempio n. 2
    def save(data: Union[np.ndarray, List[np.ndarray]], path: str,
             compress: bool = True) -> None:
        """Save data to .npz file.

        Each array is stored under the key `arr_<idx>`; the list of keys is
        stored under `__keys__` and the examples axis under `saved_axis`.

        Parameters
        ----------
        data: np.ndarray or list of np.ndarray
            The ndarray's to serialize.

        path: str
            Output npz file.

        compress: bool, default=True
            If True, uses `np.savez_compressed` to write the file. If
            False, no compression is performed.
        """
        if isinstance(data, np.ndarray):
            data = [data]

        # Convert np.ndarray's dtype to their 32 counterparts. `dtype.kind`
        # is a one-character code, so 'f' -> float32 and 'i' -> int32.
        # NOTE(review): other kinds (e.g. 'b', 'u', 'c') do not resolve to a
        # 32-bit type of the same kind - confirm inputs are float/int only.
        data = [arr.astype(arr.dtype.kind) for arr in data]

        # Check for same num of examples in the multiple ndarray's
        axis = 0 if P.data_format() == "batch_first" else -1
        assert all(data[0].shape[axis] == arr.shape[axis] for arr in data), \
            (f"Unequal num of examples in {P.data_format()} (axis={axis}): "
             f"{[arr.shape for arr in data]} - is the data format correct?")

        dataset_dict = {f"arr_{idx}": arr for idx, arr in enumerate(data)}
        dataset_dict["__keys__"] = list(dataset_dict)
        dataset_dict["saved_axis"] = axis

        # Save each ndarray (denoted by its key) separately
        writer = np.savez_compressed if compress else np.savez
        writer(path, **dataset_dict)
Esempio n. 3
    def load(path: str, as_numpy: bool = False) -> Union[np.ndarray, \
            List[np.ndarray], TFDataset, TorchDataset]:
        """Load the dataset from an .npz file.

        Parameters
        ----------
        path: str
            Npz file which contains dataset.

        as_numpy: bool, default=False
            If True, loads the dataset as a list of np.ndarray's (in
            its original form). If False, loads the dataset in
            P.backend()'s specified format.

        Returns
        -------
        data: np.ndarray, list of np.ndarray, TFDataset or TorchDataset
            The dataset (either in its original shape/format or in
            P.backend()'s specified format). A single array is returned
            bare rather than wrapped in a one-element list.
        """
        if not as_numpy:
            return TorchNumpyDataset(path) if P.backend() == "torch" \
                else TensorflowNumpyDataset(path)

        out_axis = 0 if P.data_format() == "batch_first" else -1
        with np.load(path, allow_pickle=False) as npzfile:
            # Axis that held the examples when the file was written
            saved_axis = int(npzfile["saved_axis"])
            # Move the examples axis to where the current data format
            # expects it; `__keys__` lists the stored array names.
            data = [np.moveaxis(npzfile[key], saved_axis, out_axis)
                    for key in npzfile["__keys__"]]
        return data[0] if len(data) == 1 else data
Esempio n. 4
    def load(path: str, as_numpy: bool = False) -> Union[np.ndarray, \
            List[np.ndarray], TFDataset, TorchDataset]:
        """Load the dataset from an HDF5 file.

        Parameters
        ----------
        path: str
            HDF5 file which contains dataset.

        as_numpy: bool, default=False
            If True, loads the dataset as a list of np.ndarray's (in
            its original form). If False, loads the dataset in
            P.backend()'s specified format.

        Returns
        -------
        data: np.ndarray, list of np.ndarray, TFDataset or TorchDataset
            The dataset (either in its original shape/format or in
            P.backend()'s specified format). A single array is returned
            bare rather than wrapped in a one-element list.

        Raises
        ------
        KeyError
            If the file has no `saved_axis` attribute.
        """
        if not as_numpy:
            return TorchHDF5Dataset(path) if P.backend() == "torch" \
                else TensorflowHDF5Dataset(path)

        def _numeric_order(name):
            # HDF5 groups iterate in name order by default, which puts
            # "arr_10" before "arr_2"; order by the numeric suffix instead.
            suffix = name.rsplit("_", 1)[-1]
            return (0, int(suffix), name) if suffix.isdigit() else (1, 0, name)

        out_axis = 0 if P.data_format() == "batch_first" else -1
        with h5py.File(path, "r") as h5file:
            # Fail loudly if the attribute is missing rather than passing
            # None on to np.moveaxis.
            saved_axis = int(h5file.attrs["saved_axis"])
            data = [np.moveaxis(h5file[name][:], saved_axis, out_axis)
                    for name in sorted(h5file.keys(), key=_numeric_order)]
        return data[0] if len(data) == 1 else data
Esempio n. 5
    def save(data: Union[np.ndarray, List[np.ndarray]], path: str,
             compress: bool = True) -> None:
        """Save data to .h5 (or .hdf5) file.

        Each array is written as a dataset named `arr_<idx>`; the examples
        axis is stored in the file attribute `saved_axis`.

        Parameters
        ----------
        data: np.ndarray or list of np.ndarray
            The ndarray's to serialize.

        path: str
            Output HDF5 file.

        compress: bool, default=True
            If True, uses gzip to compress the file. If False, no
            compression is performed.
        """
        if isinstance(data, np.ndarray):
            data = [data]

        # Convert np.ndarray's dtype to their 32 counterparts. `dtype.kind`
        # is a one-character code, so 'f' -> float32 and 'i' -> int32.
        # NOTE(review): other kinds (e.g. 'b', 'u', 'c') do not resolve to a
        # 32-bit type of the same kind - confirm inputs are float/int only.
        data = [arr.astype(arr.dtype.kind) for arr in data]

        # Check for same num of examples in the multiple ndarray's
        axis = 0 if P.data_format() == "batch_first" else -1
        assert all(data[0].shape[axis] == arr.shape[axis] for arr in data), \
            (f"Unequal num of examples in {P.data_format()} (axis={axis}): "
             f"{[arr.shape for arr in data]} - is the data format correct?")

        compression = "gzip" if compress else None
        with h5py.File(path, "w") as h5file:
            # Save ndarrays and axis where num_samples is represented
            h5file.attrs.create("saved_axis", axis)
            for idx, arr in enumerate(data):
                h5file.create_dataset(name=f"arr_{idx}", data=arr,
                                      chunks=True, compression=compression)
Esempio n. 6
    def load(path: str, as_numpy: bool = False) -> Union[np.ndarray, \
            List[np.ndarray], TFDataset, TorchDataset]:
        """Load the dataset from an LMDB database.

        Parameters
        ----------
        path: str
            LMDB file which contains dataset, or a directory holding the
            default "data.mdb" file.

        as_numpy: bool, default=False
            If True, loads the dataset as a list of np.ndarray's (in
            its original form). If False, loads the dataset in
            P.backend()'s specified format.

        Returns
        -------
        data: np.ndarray, list of np.ndarray, TFDataset or TorchDataset
            The dataset (either in its original shape/format or in
            P.backend()'s specified format). A single array is returned
            bare rather than wrapped in a one-element list.
        """
        # Check whether directory or full filename is provided. If dir, check
        # for "data.mdb" file within dir.
        isdir = os.path.isdir(path)
        if isdir:
            default_path = os.path.join(path, "data.mdb")
            assert os.path.isfile(default_path), "LMDB default file {} does " \
                "not exist!".format(default_path)
        else:
            assert os.path.isfile(path), "LMDB file {} does not exist!".format(path)

        if not as_numpy:
            return TorchLMDBDataset(path) if P.backend() == "torch" \
                else TensorflowLMDBDataset(path)

        out_axis = 0 if P.data_format() == "batch_first" else -1
        db = lmdb.open(path, subdir=isdir, readonly=True)
        try:
            with db.begin() as txn:
                # NOTE(review): values are unpickled; only open databases
                # from a trusted source.
                saved_axis = pkl.loads(txn.get(b"saved_axis"))
                keys = pkl.loads(txn.get(b"__keys__"))
                data = [np.moveaxis(pkl.loads(txn.get(key)), saved_axis, out_axis)
                        for key in keys]
        finally:
            # Release the environment handle even if reading fails
            db.close()
        return data[0] if len(data) == 1 else data
Esempio n. 7
    def save(data: Union[np.ndarray, List[np.ndarray]], path: str,
             save_index: bool = True) -> None:
        """Save data to .tfrecords file.

        Saves each np.ndarray under default names `arr_0`, ..., `arr_n`
        and its associated shapes as `shape_0`, ..., `shape_n`.

        NOTE: TFRecords flatten each ndarray before saving them as a
        bytes_list feature (thus losing array's shape metadata). To
        combat this, we save each ndarray's shape dims (as int64_list
        feature) and reshape them accordingly when loading.

        Parameters
        ----------
        data: np.ndarray or list of np.ndarray
            The ndarray's to serialize.

        path: str
            Output TFRecords file.

        save_index: bool, default=True
            If True, saves an index of records to `<path>_idx`. If False,
            no index is saved.
        """
        def _bytes_feature(value: Union[str, bytes]) -> tf.train.Feature:
            """Returns a bytes_list from a string / byte."""
            if isinstance(value, type(tf.constant(0))):
                value = value.numpy() # BytesList won't unpack string from an EagerTensor.
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

        def _float_feature(value: float) -> tf.train.Feature:
            """Returns a float_list from a float / double."""
            if not isinstance(value, (list, np.ndarray)):
                value = [value] # FloatList won't unpack unless it is an list/np.array.
            return tf.train.Feature(float_list=tf.train.FloatList(value=value))

        def _int64_feature(value: Union[bool, int]) -> tf.train.Feature:
            """Returns an int64_list from a bool / enum / int / uint."""
            if not isinstance(value, (list, np.ndarray)):
                value = [value] # Int64List won't unpack, unless it is an list/np.array.
            return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

        def _serialize(example: Dict[str, Dict[str, Any]]) -> bytes:
            """Serialize an example within the dataset."""
            dset_item = {name: spec["_type"](spec["data"])
                         for name, spec in example.items()}
            # Build the proto once, after all features are converted
            features = tf.train.Features(feature=dset_item)
            example_proto = tf.train.Example(features=features)
            return example_proto.SerializeToString()

        def _create_idx(tfrecord_file: str, index_file: str) -> None:
            """Create index of TFRecords file.

            Each row of the index file corresponds to one example in the
            dataset: the first column is the byte offset of the record
            and the last column indicates how many bytes of storage the
            record takes.
            """
            with open(tfrecord_file, "rb") as infile, \
                    open(index_file, "w") as outfile:
                while True:
                    current = infile.tell()
                    try:
                        # Record layout: 8-byte length, 4-byte length CRC,
                        # payload, 4-byte payload CRC.
                        byte_len = infile.read(8)
                        if len(byte_len) == 0:
                            break
                        infile.read(4)
                        proto_len = struct.unpack("q", byte_len)[0]
                        infile.read(proto_len)
                        infile.read(4)
                        outfile.write(str(current) + " " + str(infile.tell() - current) + "\n")
                    except (struct.error, OSError):
                        print("Failed to parse TFRecord.")
                        break

        if isinstance(data, np.ndarray):
            data = [data]

        # Convert np.ndarray's dtype to their 32 counterparts. `dtype.kind`
        # is a one-character code, so 'f' -> float32 and 'i' -> int32.
        # NOTE(review): other kinds (e.g. 'b', 'u', 'c') do not resolve to a
        # 32-bit type of the same kind - confirm inputs are float/int only.
        data = [arr.astype(arr.dtype.kind) for arr in data]

        # Check for same num of examples in the multiple ndarray's
        batch_first = P.data_format() == "batch_first"
        axis = 0 if batch_first else -1
        assert all(data[0].shape[axis] == arr.shape[axis] for arr in data), \
            (f"Unequal num of examples in {P.data_format()} (axis={axis}): "
             f"{[arr.shape for arr in data]} - is the data format correct?")

        # Add shapes of each array in the dataset (for a single example). Hack
        # to allow serialized data to be reshaped properly when loaded.
        shapes = {f"shape_{idx}": np.delete(arr.shape, axis)
                  for idx, arr in enumerate(data)}
        arrays = {f"arr_{idx}": arr for idx, arr in enumerate(data)}

        # Write serialized example(s) into the dataset
        n_examples = data[0].shape[axis]
        with tf.io.TFRecordWriter(path) as writer:
            for row in tqdm(range(n_examples), total=n_examples):
                # Save metadata about each array (aka shape) as int64 feature
                example = {key: {"data": shape, "_type": _int64_feature}
                           for key, shape in shapes.items()}
                # NOTE: tobytes() flattens an ndarray. We have to flatten it
                # because tf _bytes_feature() only takes in bytes. To combat
                # this, we save each ndarray's shape as well (see above).
                for key, arr in arrays.items():
                    sample = arr[row] if batch_first else arr[..., row]
                    example[key] = {"data": sample.tobytes(),
                                    "_type": _bytes_feature}
                writer.write(_serialize(example))

        # Write index of the examples
        # NOTE: It's recommended to create an index file for each TFRecord file.
        # Index file must be provided when using multiple workers, otherwise the
        # loader may return duplicate records.
        if save_index:
            _create_idx(tfrecord_file=path, index_file=f"{path}_idx")
Esempio n. 8
    def save(data: Union[np.ndarray, List[np.ndarray]], path: str,
             write_frequency: int = 1) -> None:
        """Save data to .lmdb/.mdb file.

        Each array is pickled under the key `arr_<idx>`; the list of keys
        is stored under `__keys__` and the examples axis under
        `saved_axis`.

        Parameters
        ----------
        data: np.ndarray or list of np.ndarray
            The ndarray's to serialize.

        path: str
            Output LMDB directory or file.

        write_frequency: int, default=1
            The frequency to write back data to disk. Smaller value
            reduces memory usage, at the cost of performance.
        """
        if isinstance(data, np.ndarray):
            data = [data]

        # Convert np.ndarray's dtype to their 32 counterparts. `dtype.kind`
        # is a one-character code, so 'f' -> float32 and 'i' -> int32.
        # NOTE(review): other kinds (e.g. 'b', 'u', 'c') do not resolve to a
        # 32-bit type of the same kind - confirm inputs are float/int only.
        data = [arr.astype(arr.dtype.kind) for arr in data]

        # Check for same num of examples in the multiple ndarray's
        axis = 0 if P.data_format() == "batch_first" else -1
        assert all(data[0].shape[axis] == arr.shape[axis] for arr in data), \
            (f"Unequal num of examples in {P.data_format()} (axis={axis}): "
             f"{[arr.shape for arr in data]} - is the data format correct?")

        # Check whether directory or full filename is provided. If dir, check
        # for "data.mdb" file within dir.
        isdir = os.path.isdir(path)
        if isdir:
            assert not os.path.isfile(os.path.join(path, "data.mdb")), \
                "LMDB file {} exists!".format(os.path.join(path, "data.mdb"))
        else:
            assert not os.path.isfile(path), "LMDB file {} exists!".format(path)

        # It's OK to use super large map_size on Linux, but not on other platforms
        map_size = 1099511627776 * 2 if platform.system() == 'Linux' else 128 * 10**6
        db = lmdb.open(path, subdir=isdir, map_size=map_size, readonly=False,
                       meminit=False, map_async=True) # need sync() at the end

        # Put data into lmdb, and doubling the size if full.
        def put_or_grow(txn, key, value):
            while True:
                try:
                    txn.put(key, value)
                    return txn
                except lmdb.MapFullError:
                    # A transaction that hit the map limit cannot be
                    # reused; drop it before resizing the environment.
                    txn.abort()
                curr_size = db.info()['map_size']
                new_size = curr_size * 2
                print(f"Doubling LMDB map_size to {new_size / 10**9:.2f} GB.")
                db.set_mapsize(new_size)
                txn = db.begin(write=True)

        try:
            # NOTE: LMDB transaction is not exception-safe (even though it has a
            # context manager interface).
            txn = db.begin(write=True)
            for idx, arr in enumerate(data):
                key = f"arr_{idx}".encode('ascii')
                txn = put_or_grow(txn, key=key, value=pkl.dumps(arr, protocol=-1))
                # NOTE: If we do not commit some ndarrays before the db grows,
                # those do not get saved. As such, for robustness, we choose
                # write_frequency=1 (at the cost of performance).
                if (idx + 1) % write_frequency == 0:
                    txn.commit()
                    txn = db.begin(write=True)
            txn.commit() # commit all remaining serialized ndarrays

            # Add all keys used (in this case it is just the array names) and
            # axis where num_samples is represented. put_or_grow may replace
            # the transaction, so it is committed explicitly rather than
            # through a `with` block bound to the first transaction object.
            keys = [f'arr_{idx}'.encode('ascii') for idx in range(len(data))]
            txn = db.begin(write=True)
            txn = put_or_grow(txn, key=b'__keys__', value=pkl.dumps(keys, protocol=-1))
            txn = put_or_grow(txn, key=b'saved_axis', value=pkl.dumps(axis, protocol=-1))
            txn.commit()

            print("Flushing database ...")
            # map_async=True defers writes; force them to disk before closing
            db.sync()
        finally:
            db.close()
Esempio n. 9
    def parse(self,
              df: pd.DataFrame,
              target_index: Optional[List[int]] = None,
              return_is_successful: bool = True) -> Dict[str, Any]:
        """Parse dataframe using the preprocessor given.

        Each row's data column is featurized with `self.preprocessor`,
        either as a SMILES string or as a sequence depending on
        `self.process_as_seq`. Rows that raise during parsing are
        reported, counted as failures and skipped.

        Parameters
        ----------
        df: pd.DataFrame
            DataFrame to be parsed.

        target_index: list of int or None, optional, default=None
            Indices to extract. If None, then all examples (in the dataset)
            are parsed. Allows for easier batching.

        return_is_successful: bool, optional, default=True
            If True, boolean list (representing whether parsing of the
            sequence has succeeded or not) is returned in the key
            'is_successful'. If False, `None` is returned instead.

        Returns
        -------
        dict
            `"dataset"`: list of np.ndarray, one per feature column (labels
            last, when `self.labels` is set); `"is_successful"`: np.ndarray
            or None.
        """
        # NOTE(review): this block reached review with lines dropped
        # (`try:`, `else:`, closing brackets, several call arguments). The
        # code below is left exactly as received; each gap is flagged and
        # must be restored from the upstream source before use.
        features = None
        is_successful_list = []
        pp = self.preprocessor
        mutator = self.mutator
        processed_as = 'sequence' if self.process_as_seq else 'SMILES'

        if target_index is not None:
            df = df.iloc[target_index]

        # Resolve column positions once so rows can be indexed positionally
        data_index = df.columns.get_loc(self.data_col)
        pdb_index = df.columns.get_loc(
            self.pdb_col) if self.pdb_col is not None else None
        pos_index = df.columns.get_loc(
            self.pos_col) if self.pos_col is not None else None
        labels_index = [] if self.labels is None else [
            df.columns.get_loc(l) for l in self.labels
        # NOTE(review): closing `]` of the list above is missing.

        fail_count = 0
        success_count = 0
        total_count = df.shape[0]
        for row in tqdm(df.itertuples(index=False), total=total_count):
            data: Optional[Union[str, List[str]]] = row[data_index]
            pdbid = row[pdb_index] if pdb_index is not None else None
            positions = row[pos_index] if pos_index is not None else None
            labels = [row[i] for i in labels_index]

            # NOTE(review): a `try:` presumably opened here (the matching
            # `except Exception` is further down) - confirm.
                # Check for valid data input
                if data is None:
                    raise TypeError("Invalid type: {}. Should be str or list " \
                        "of str.".format(type(data).__name__))
                elif len(data) == 0:
                    # Raise error for now, if empty list or str is passed in.
                    # TODO: Change how each type (molecule or sequence) feature
                    # processing handles empty data. If mol.GetNumAtoms() == 0
                    # or len(seq) == 0, then a respective FeatureExtractionError
                    # should be raised.
                    raise ValueError("Cannot process empty data.")

                # SMILES parsing
                if not self.process_as_seq:
                    if mutator is not None:
                        # NOTE(review): the `raise ...(` opener and the
                        # `.format(` argument are missing here.
                            "SMILES string '{}' cannot be mutated.".format(

                    # SMILES string can only be processed as rdkit.Mol instance.
                    mol = rdmolfiles.MolFromSmiles(data, sanitize=True)
                    if mol is None:
                        raise TypeError("Invalid type: {}. Should be " \
                        # NOTE(review): remainder of this message is missing.

                    # Compute features if its a proper molecule
                    if isinstance(pp, MolPreprocessor):
                        input_feats = pp.get_input_feats(mol)
                    # NOTE(review): an `else:` and the list element
                    # expression / closing `]` appear to be missing below.
                        valid_preprocessors = [
                            for pp in preprocess_method_dict.values()
                            if isinstance(pp(), MolPreprocessor)
                        raise ValueError("{} cannot compute features for SMILES-based input " \
                            "'{}'. Choose a valid SMILES-based preprocessor: {}.".format( \
                            type(pp).__name__, data, valid_preprocessors))
                # NOTE(review): an `else:` for the sequence branch appears
                # to be missing here.
                    # Sequence-based parsing
                    if mutator is not None:
                        if pdbid is None:
                            raise ValueError(
                                "PDB ID not specified. Unable to mutate residue."
                        # NOTE(review): closing `)` missing above.

                        if positions is None:
                            raise ValueError("Positions not specified. PDBMutator needs " \
                                "residue positions to mutate residues at defined locations.")
                        # NOTE(review): an `else:` appears to be missing here.
                            # Raise error for now, as lengths of positions and seqs need to match
                            # to work with the current implementation of mutator.
                            # TODO: Change when implementation of mutator changes.
                            # NOTE: Should we assume that if the len(positions) < len(data), then
                            # the user wants to modify those positions in the sequence?
                            if len(data) != len(positions):
                                raise ValueError("Length of input (N={}) is not the same as number " \
                                    "of positions (N={}) to modify. Did you pass in the full " \
                                    "sequence? Currently, mutations can only be performed with " \
                                    "information about which residue position(s) to modify and the " \
                                    "replacement residue(s) at those positions. If you want to " \
                                    "process only the input sequence (without any mutations), " \
                                    "set mutator=None.".format(len(data), len(positions)))
                            # Mutate residues (to primary or tertiary) based off mutator instance
                            replace_with = {
                                resid: data[i]
                                for i, resid in enumerate(positions)
                            # NOTE(review): closing `}` and the remaining
                            # `mutator.mutate(...)` arguments are missing.
                            data = mutator.mutate(pdbid,

                        # Obtain features based on which preprocessor is used
                        # NOTE(review): the class checked by this isinstance
                        # and the following `else:` are missing.
                        if isinstance(pp,
                            input_feats = pp.get_input_feats(data)
                            raise NotImplementedError
                    # NOTE(review): an `else:` (no mutator) appears to be
                    # missing here.
                        # Since it is not mutated, the data can now ONLY be a sequence
                        # (since 3D representation cannot be within a single column in a df)
                        if isinstance(pp, SequencePreprocessor):
                            input_feats = pp.get_input_feats(data)
                        # NOTE(review): an `else:` and the list element
                        # expression / closing `]` appear to be missing below.
                            valid_preprocessors = [
                                for pp in preprocess_method_dict.values()
                                if isinstance(pp(), SequencePreprocessor)
                            raise ValueError("{} cannot compute features for sequence-based input " \
                                "'{}'. Either mutate data (by passing in PDBMutator instance) to " \
                                "'tertiary' structure or choose a valid sequence-based preprocessor: " \
                                "{}.".format(type(pp).__name__, data, valid_preprocessors))
            except Exception as e:
                # If for some reason the data cannot be parsed properly, skip
                print('Error while parsing `{}` as {}, type: {}, {}'.format(\
                    data, processed_as, type(e).__name__, e.args))
                fail_count += 1
                if return_is_successful:
                # NOTE(review): body of this `if` is missing - presumably
                # records a failure in is_successful_list and skips to the
                # next row; confirm.

            # Initialize features: list of lists
            if features is None:
                num_feats = len(input_feats) if isinstance(input_feats,
                                                           tuple) else 1
                if self.labels is not None:
                    num_feats += 1
                features = [[] for _ in range(num_feats)]

            # Append computed features to respective cols
            if isinstance(input_feats, tuple):
                for i in range(len(input_feats)):
                # NOTE(review): loop body and the non-tuple `else:` branch
                # are missing.

            # Add label values as last column, if provided
            if self.labels is not None:
                features[len(features) - 1].append(labels)

            success_count += 1
            if return_is_successful:
            # NOTE(review): body of this `if` is missing - presumably
            # records a success in is_successful_list; confirm.

        print('Preprocess finished. FAIL {}, SUCCESS {}, TOTAL {}'.format(\
            fail_count, success_count, total_count))

        # Compile feature(s) into individual np.ndarray(s), padding each to max
        # dims, if necessary. NOTE: The num of examples in the dataset depends
        # on the data_format specified (represented by first/last channel).
        all_feats = [broadcast_array(feature)
                     for feature in features] if features else []
        if P.data_format() == "batch_last":
            # Features are stacked along axis 0; move examples to the last axis
            all_feats = [np.moveaxis(feat, 0, -1) for feat in all_feats]
        is_successful = np.array(
            is_successful_list) if return_is_successful else None
        return {"dataset": all_feats, "is_successful": is_successful}