Code Example #1
File: serialize.py  Project: ai-med/almgig
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('graphs', metavar='PICKLE_FILE',
                        help='Path to pickled graphs.')
    parser.add_argument('--data', choices=['gdb9', 'zinc'], required=True,
                        help='Data to serialize.')
    parser.add_argument('--reward_type', type=RewardType.from_string,
                        metavar=RewardType.metavar(),
                        required=True)
    parser.add_argument('--norm_file',
                        help='Path to file to standardize penalized logP score.')

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    filename = Path(args.graphs)
    data, max_nodes = read_graph_data(filename, get_dataset(args.data).MAX_NODES)

    validator = Graph2MolValidator()

    conv = get_decoder(args.data)
    mol_metrics = GraphMolecularMetrics(conv, args.reward_type, args.norm_file)

    outfile = filename.with_suffix('.mdb')
    LOG.info("Saving to %s", outfile)
    ds = create_dataflow(data,
                         max_nodes=max_nodes,
                         metrics_fn=mol_metrics.get_reward_metrics,
                         validator=validator,
                         shuffle=False)
    LMDBSerializer.save(ds, str(outfile))

    LOG.warning('%d erroneous molecules', validator.n_errors)
Code Example #2
def serialize_to_lmdb(dataset, hparams, lmdb_path):
    if os.path.isfile(lmdb_path):
        print("lmdb file ({}) exists!".format(lmdb_path))
    else:
        df = DataFromList(dataset, shuffle=False)
        # map_func is defined elsewhere in the project (not shown in this snippet)
        df = MapData(df, lambda data: map_func(data, hparams))
        print("Creating lmdb cache...")
        LMDBSerializer.save(df, lmdb_path)
Code Example #3
def dump_imdb(ds, output_path, parallel=None):
    """
    Create a Single-File LMDB from raw images.
    """
    if parallel is None:
        parallel = min(40,
                       multiprocessing.cpu_count() //
                       2)  # assuming hyperthreading

    def mapf(dp):
        fname, label = dp
        with open(fname, 'rb') as f:
            jpeg = f.read()
        # renamed from `bytes` to avoid shadowing the built-in
        jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
        return jpeg, label

    ds = MultiThreadMapData(ds, 1, mapf, buffer_size=2000, strict=True)
    ds = MultiProcessRunnerZMQ(ds, num_proc=parallel)

    LMDBSerializer.save(ds, output_path)
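
For context, a hypothetical call to the function above, using tensorpack's ILSVRC12Files dataflow (which, as in the later examples, yields (filename, label) pairs matching what mapf expects); the paths are placeholders, not from the source:

# Hypothetical usage sketch (paths are placeholders).
from tensorpack.dataflow import dataset

raw = dataset.ILSVRC12Files('/path/to/imagenet', 'train', shuffle=False)
dump_imdb(raw, '/data/ILSVRC-train.lmdb', parallel=8)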
Code Example #4
def get_tp_loader(data_dir, name, batch_size, parallel=None):
    isTrain = name == 'train'
    augmentors = get_tp_augmentor(isTrain)

    if data_dir.endswith('lmdb'):
        # 500000[70:87:20, 1.95it/s]
        data_dir = os.path.join(data_dir, 'ILSVRC-%s.lmdb' % name)
        ds = LMDBSerializer.load(data_dir, shuffle=False)
        ds = get_sequential_loader(ds, isTrain, batch_size, augmentors,
                                   parallel)
    else:
        # 500000[27:11:03, 5.11it/s]
        if isTrain:
            ds = dataset.ILSVRC12(data_dir, name, shuffle=True)
        else:
            ds = dataset.ILSVRC12Files(data_dir, name, shuffle=False)
        ds = get_random_loader(ds, isTrain, batch_size, augmentors, parallel)
    return ds
Code Example #5
def get_infer_iterator(dataset, hparams, lmdb_path):

    serialize_to_lmdb(dataset, hparams, lmdb_path)

    batch_size = hparams.infer_batch_size
    num_gpu = hparams.num_gpu

    df = LMDBSerializer.load(lmdb_path, shuffle=False)

    batched_df = BatchData(df, batch_size=batch_size, remainder=False)
    splitted_df = MapData(
        batched_df,
        lambda x: [np.array_split(x[idx], num_gpu) for idx in range(len(x))])
    prefetched_df = PrefetchDataZMQ(splitted_df,
                                    nr_proc=1,
                                    hwm=batch_size * 10)

    return prefetched_df
Code Example #6
def get_iterator(hparams,
                 dataset,
                 lmdb_path,
                 shuffle=True,
                 drop_remainder=True,
                 nr_proc=4):

    # argument order per the definition of serialize_to_lmdb in Code Example #2
    serialize_to_lmdb(dataset, hparams, lmdb_path)

    batch_size = hparams.batch_size
    num_gpu = hparams.num_gpu
    df = LMDBSerializer.load(lmdb_path, shuffle=shuffle)

    batched_df = BatchData(df,
                           batch_size=batch_size,
                           remainder=not drop_remainder)
    splitted_df = MapData(
        batched_df,
        lambda x: [np.array_split(x[idx], num_gpu) for idx in range(len(x))])
    prefetched_df = PrefetchDataZMQ(splitted_df,
                                    nr_proc=nr_proc,
                                    hwm=batch_size * 10)

    return prefetched_df
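
The MapData lambda above reshapes each batched datapoint into per-GPU shards via np.array_split; a minimal sketch of that mapping in isolation (the shapes and num_gpu value are illustrative):

import numpy as np

# One batched datapoint with two components, e.g. features and labels.
x = [np.arange(8).reshape(8, 1), np.arange(8)]
num_gpu = 2

shards = [np.array_split(x[idx], num_gpu) for idx in range(len(x))]
# shards[0] -> two arrays of shape (4, 1), one per GPU
# shards[1] -> [array([0, 1, 2, 3]), array([4, 5, 6, 7])]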
Code Example #7
File: prepare_data.py  Project: askerlee/featlens
# NOTE: this snippet starts mid-file. The head of BinaryILSVRC12 below is
# reconstructed from Code Example #9, which defines the same class, and the
# argparse setup that defines `args` (with .n and .check_only) is truncated
# in the source.
class BinaryILSVRC12(dataset.ILSVRC12Files):
    def __iter__(self):
        for fname, label in super(BinaryILSVRC12, self).__iter__():
            with open(fname, 'rb') as f:
                jpeg = f.read()
            jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
            yield [jpeg, label]


imagenet_path = os.environ['IMAGENET']

for name in ['train', 'val']:  # ['test']
    ds0 = BinaryILSVRC12(imagenet_path, name)
    ds1 = MultiProcessRunnerZMQ(ds0, num_proc=1)  # num_proc; nr_proc was the old PrefetchDataZMQ keyword
    # dftools.dump_dataflow_to_lmdb(ds1, os.path.join(imagenet_path, 'ILSVRC-%s.lmdb' % name))
    if args.n == 1:
        paths = [os.path.join(imagenet_path, 'ILSVRC-%s.lmdb' % name)]
    else:
        paths = [os.path.join(imagenet_path, 'ILSVRC-%s-%d.lmdb' % (name, i))
                 for i in range(args.n)]

    if not args.check_only:
        if args.n == 1:
            LMDBSerializer.save(ds1, paths[0])
        else:
            print("Saving to %d files:\n%s\n" % (args.n, "\n".join(paths)))
            LMDBSplitSaver.save(ds1, paths, args.n)

    # Sanity check: the datapoint count in the LMDB file(s) must match the
    # original dataflow.
    orig_total_img_count = len(ds0)
    lmdb_total_img_count = 0
    for i in range(args.n):
        ds = LMDBSerializer.load(paths[i], shuffle=False)
        lmdb_total_img_count += len(ds)

    print("'%s' orig: %d, lmdb: %d." % (name, orig_total_img_count, lmdb_total_img_count), end=' ')
    if orig_total_img_count != lmdb_total_img_count:
        print("Mismatch!")
        pdb.set_trace()
    else:
        print("Match.")  # the else-branch body is truncated in the source; a confirmation print is assumed
Code Example #8
            # (The head of this method is truncated in the source snippet: this is
            # the tail of a min-max intensity-rescaling method, and `minimum` is
            # presumably computed symmetrically above, e.g.
            # np.amin(img, initial=self.max, keepdims=True).)
            maximum = np.amax(img, initial=self.min, keepdims=True)

        # Linearly rescale `img` from [minimum, maximum] to [self.min, self.max].
        img = (self.max - self.min) * (img - minimum) / (maximum - minimum) + self.min

        return img


# Testcode for encode/decode.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--create', action='store_true', help='create lmdb')
    parser.add_argument('--debug', action='store_true', help='debug images')
    parser.add_argument('--input', type=str, help='path to coco zip', required=True)
    parser.add_argument('--lmdb', type=str, help='path to output lmdb', required=True)
    args = parser.parse_args()

    ds = ImageDataFromZIPFile(args.input)
    ds = ImageDecodeYCrCb(ds, index=0)
    # ds = RejectTooSmallImages(ds, index=0)
    ds = CenterSquareResize(ds, index=0)
    if args.create:
        ds = ImageEncode(ds, index=0)
        LMDBSerializer.save(ds, args.lmdb)
    if args.debug:
        ds.reset_state()
        for i in ds:
            cv2.imshow('example', i[0])
            cv2.waitKey(0)


Code Example #9
import numpy as np
from tensorpack.dataflow import *


class BinaryILSVRC12(dataset.ILSVRC12Files):
    def __iter__(self):
        for fname, label in super(BinaryILSVRC12, self).__iter__():
            with open(fname, 'rb') as f:
                jpeg = f.read()
            jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
            yield [jpeg, label]


from tensorpack.dataflow.serialize import LMDBSerializer
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--ds', type=str)
    parser.add_argument('-s', '--split', type=str, default="val")
    parser.add_argument('--out', type=str, default=".")
    parser.add_argument('-p', '--procs', type=int, default=20)

    args = parser.parse_args()

    import os.path as osp
    ds0 = BinaryILSVRC12(args.ds, args.split)
    ds1 = PrefetchDataZMQ(ds0, nr_proc=args.procs)
    LMDBSerializer.save(ds1, osp.join(args.out, '%s.lmdb' % args.split))
Code Example #10
import numpy as np
import cv2
import os

from tensorpack.dataflow import *
from tensorpack.dataflow.serialize import LMDBSerializer


class Dummy(DataFlow):

    def get_data(self):
        # `range` here; the original used the Python 2-only `xrange`
        for cur_id in range(1000000):
            yield [cur_id]

    def size(self):
        return 1000000


if __name__ == '__main__':
    output = os.path.expanduser('~/workspace/tmp/dummy_1M.lmdb')
    LMDBSerializer.save(Dummy(), output, write_frequency=100000)
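
To round-trip the file written above, a minimal sketch of loading it back (reusing the API from the earlier examples; reset_state() is required before iterating a tensorpack DataFlow):

# Read the dummy LMDB back (a sketch, not part of the original snippet).
ds = LMDBSerializer.load(output, shuffle=False)
ds.reset_state()
print(len(ds))        # -> 1000000
for dp in ds:
    print(dp)         # -> [0]
    break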