Пример #1
0
class DBFileReader(Reader):

    default_name_suffix = 'db_file_reader'
    """Reader reads from a DB file.

    Example usage:
    db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB')

    Args:
        db_path: str.
        db_type: str. DB type of file. A db_type is registed by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
        name: str or None. Name of DBFileReader.
            Optional name to prepend to blobs that will store the data.
            Default to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read for each time the read_net is run.
        loop_over: bool.
            If True given, will go through examples in random order endlessly.
        field_names: List[str]. If the schema.field_names() should not in
            alphabetic order, it must be specified.
            Otherwise, schema will be automatically restored with
            schema.field_names() sorted in alphabetic order.
    """
    def __init__(
        self,
        db_path,
        db_type,
        name=None,
        batch_size=100,
        loop_over=False,
        field_names=None,
    ):
        assert db_path is not None, "db_path can't be None."
        assert db_type in C.registered_dbs(), \
            "db_type [{db_type}] is not available. \n" \
            "Choose one of these: {registered_dbs}.".format(
                db_type=db_type,
                registered_dbs=C.registered_dbs(),
        )

        self.db_path = os.path.expanduser(db_path)
        self.db_type = db_type
        self.name = name or '{db_name}_{default_name_suffix}'.format(
            db_name=self._extract_db_name_from_db_path(),
            default_name_suffix=self.default_name_suffix,
        )
        self.batch_size = batch_size
        self.loop_over = loop_over

        # Before self._init_reader_schema(...),
        # self.db_path and self.db_type are required to be set.
        super(DBFileReader,
              self).__init__(self._init_reader_schema(field_names))
        self.ds = Dataset(self._schema, self.name + '_dataset')
        self.ds_reader = None

    def _init_name(self, name):
        return name or self._extract_db_name_from_db_path() + '_db_file_reader'

    def _init_reader_schema(self, field_names=None):
        """Restore a reader schema from the DB file.

        If `field_names` given, restore scheme according to it.

        Overwise, loade blobs from the DB file into the workspace,
        and restore schema from these blob names.
        It is also assumed that:
        1). Each field of the schema have corresponding blobs
            stored in the DB file.
        2). Each blob loaded from the DB file corresponds to
            a field of the schema.
        3). field_names in the original schema are in alphabetic order,
            since blob names loaded to the workspace from the DB file
            will be in alphabetic order.

        Load a set of blobs from a DB file. From names of these blobs,
        restore the DB file schema using `from_column_list(...)`.

        Returns:
            schema: schema.Struct. Used in Reader.__init__(...).
        """
        if field_names:
            return from_column_list(field_names)

        assert os.path.exists(self.db_path), \
            'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
        with core.NameScope(self.name):
            # blob_prefix is for avoiding name conflict in workspace
            blob_prefix = scope.CurrentNameScope()
        workspace.RunOperatorOnce(
            core.CreateOperator(
                'Load',
                [],
                [],
                absolute_path=True,
                db=self.db_path,
                db_type=self.db_type,
                load_all=True,
                add_prefix=blob_prefix,
            ))
        col_names = [
            blob_name[len(blob_prefix):] for blob_name in workspace.Blobs()
            if blob_name.startswith(blob_prefix)
        ]
        schema = from_column_list(col_names)
        return schema

    def setup_ex(self, init_net, finish_net):
        """From the Dataset, create a _DatasetReader and setup a init_net.

        Make sure the _init_field_blobs_as_empty(...) is only called once.

        Because the underlying NewRecord(...) creats blobs by calling
        NextScopedBlob(...), so that references to previously-initiated
        empty blobs will be lost, causing accessibility issue.
        """
        if self.ds_reader:
            self.ds_reader.setup_ex(init_net, finish_net)
        else:
            self._init_field_blobs_as_empty(init_net)
            self._feed_field_blobs_from_db_file(init_net)
            self.ds_reader = self.ds.random_reader(
                init_net,
                batch_size=self.batch_size,
                loop_over=self.loop_over,
            )
            self.ds_reader.sort_and_shuffle(init_net)
            self.ds_reader.computeoffset(init_net)

    def read(self, read_net):
        assert self.ds_reader, 'setup_ex must be called first'
        return self.ds_reader.read(read_net)

    def _init_field_blobs_as_empty(self, init_net):
        """Initialize dataset field blobs by creating an empty record"""
        with core.NameScope(self.name):
            self.ds.init_empty(init_net)

    def _feed_field_blobs_from_db_file(self, net):
        """Load from the DB file at db_path and feed dataset field blobs"""
        assert os.path.exists(self.db_path), \
            'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
        net.Load(
            [],
            self.ds.get_blobs(),
            db=self.db_path,
            db_type=self.db_type,
            absolute_path=True,
            source_blob_names=self.ds.field_names(),
        )

    def _extract_db_name_from_db_path(self):
        """Extract DB name from DB path

            E.g. given self.db_path=`/tmp/sample.db`,
            it returns `sample`.

            Returns:
                db_name: str.
        """
        return os.path.basename(self.db_path).rsplit('.', 1)[0]
Пример #2
0
    def test_record_queue(self):
        num_prod = 8
        num_consume = 3
        schema = Struct(
            ('floats', Map(
                Scalar(np.int32),
                Scalar(np.float32))),
        )
        contents_raw = [
            [1, 2, 3],  # len
            [11, 21, 22, 31, 32, 33],  # key
            [1.1, 2.1, 2.2, 3.1, 3.2, 3.3],  # value
        ]
        contents = from_blob_list(schema, contents_raw)
        ds = Dataset(schema)
        net = core.Net('init')
        ds.init_empty(net)

        content_blobs = NewRecord(net, contents)
        FeedRecord(content_blobs, contents)
        writer = ds.writer(init_net=net)
        writer.write_record(net, content_blobs)
        reader = ds.reader(init_net=net)

        # prepare receiving dataset
        rec_dataset = Dataset(contents, name='rec')
        rec_dataset.init_empty(init_net=net)
        rec_dataset_writer = rec_dataset.writer(init_net=net)

        workspace.RunNetOnce(net)

        queue = RecordQueue(contents, num_threads=num_prod)

        def process(net, fields):
            new_fields = []
            for f in fields.field_blobs():
                new_f = net.Copy(f)
                new_fields.append(new_f)
            new_fields = from_blob_list(fields, new_fields)
            return new_fields

        q_reader, q_step, q_exit, fields = queue.build(reader, process)
        producer_step = core.execution_step('producer', [q_step, q_exit])

        consumer_steps = []
        for i in range(num_consume):
            name = 'queue_reader_' + str(i)
            net_consume = core.Net(name)
            should_stop, fields = q_reader.read_record(net_consume)
            step_consume = core.execution_step(name, net_consume)

            name = 'dataset_writer_' + str(i)
            net_dataset = core.Net(name)
            rec_dataset_writer.write(net_dataset, fields.field_blobs())
            step_dataset = core.execution_step(name, net_dataset)

            step = core.execution_step(
                'consumer_' + str(i),
                [step_consume, step_dataset],
                should_stop_blob=should_stop)
            consumer_steps.append(step)
        consumer_step = core.execution_step(
            'consumers', consumer_steps, concurrent_substeps=True)

        work_steps = core.execution_step(
            'work', [producer_step, consumer_step], concurrent_substeps=True)

        plan = core.Plan('test')
        plan.AddStep(work_steps)
        core.workspace.RunPlan(plan)
        data = workspace.FetchBlobs(rec_dataset.get_blobs())
        self.assertEqual(6, sum(data[0]))
        self.assertEqual(150, sum(data[1]))
        self.assertAlmostEqual(15, sum(data[2]), places=5)
Пример #3
0
    def test_record_queue(self):
        num_prod = 8
        num_consume = 3
        schema = Struct(
            ('floats', Map(Scalar(np.int32), Scalar(np.float32))), )
        contents_raw = [
            [1, 2, 3],  # len
            [11, 21, 22, 31, 32, 33],  # key
            [1.1, 2.1, 2.2, 3.1, 3.2, 3.3],  # value
        ]
        contents = from_blob_list(schema, contents_raw)
        ds = Dataset(schema)
        net = core.Net('init')
        ds.init_empty(net)

        content_blobs = NewRecord(net, contents)
        FeedRecord(content_blobs, contents)
        writer = ds.writer(init_net=net)
        writer.write_record(net, content_blobs)
        reader = ds.reader(init_net=net)

        # prepare receiving dataset
        rec_dataset = Dataset(contents, name='rec')
        rec_dataset.init_empty(init_net=net)
        rec_dataset_writer = rec_dataset.writer(init_net=net)

        workspace.RunNetOnce(net)

        queue = RecordQueue(contents, num_threads=num_prod)

        def process(net, fields):
            new_fields = []
            for f in fields.field_blobs():
                new_f = net.Copy(f)
                new_fields.append(new_f)
            new_fields = from_blob_list(fields, new_fields)
            return new_fields

        q_reader, q_step, q_exit, fields = queue.build(reader, process)
        producer_step = core.execution_step('producer', [q_step, q_exit])

        consumer_steps = []
        for i in range(num_consume):
            name = 'queue_reader_' + str(i)
            net_consume = core.Net(name)
            should_stop, fields = q_reader.read_record(net_consume)
            step_consume = core.execution_step(name, net_consume)

            name = 'dataset_writer_' + str(i)
            net_dataset = core.Net(name)
            rec_dataset_writer.write(net_dataset, fields.field_blobs())
            step_dataset = core.execution_step(name, net_dataset)

            step = core.execution_step('consumer_' + str(i),
                                       [step_consume, step_dataset],
                                       should_stop_blob=should_stop)
            consumer_steps.append(step)
        consumer_step = core.execution_step('consumers',
                                            consumer_steps,
                                            concurrent_substeps=True)

        work_steps = core.execution_step('work',
                                         [producer_step, consumer_step],
                                         concurrent_substeps=True)

        plan = core.Plan('test')
        plan.AddStep(work_steps)
        core.workspace.RunPlan(plan)
        data = workspace.FetchBlobs(rec_dataset.get_blobs())
        self.assertEqual(6, sum(data[0]))
        self.assertEqual(150, sum(data[1]))
        self.assertAlmostEqual(15, sum(data[2]), places=5)
Пример #4
0
class DBFileReader(Reader):

    default_name_suffix = 'db_file_reader'

    """Reader reads from a DB file.

    Example usage:
    db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB')

    Args:
        db_path: str.
        db_type: str. DB type of file. A db_type is registed by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
        name: str or None. Name of DBFileReader.
            Optional name to prepend to blobs that will store the data.
            Default to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read for each time the read_net is run.
    """
    def __init__(
        self,
        db_path,
        db_type,
        name=None,
        batch_size=100,
    ):
        assert db_path is not None, "db_path can't be None."
        assert db_type in C.registered_dbs(), \
            "db_type [{db_type}] is not available. \n" \
            "Choose one of these: {registered_dbs}.".format(
                db_type=db_type,
                registered_dbs=C.registered_dbs(),
        )

        self.db_path = db_path
        self.db_type = db_type
        self.name = name or '{db_name}_{default_name_suffix}'.format(
            db_name=self._extract_db_name_from_db_path(),
            default_name_suffix=self.default_name_suffix,
        )
        self.batch_size = batch_size

        # Before self._init_reader_schema(...),
        # self.db_path and self.db_type are required to be set.
        super(DBFileReader, self).__init__(self._init_reader_schema())
        self.ds = Dataset(self._schema, self.name + '_dataset')
        self.ds_reader = None

    def _init_name(self, name):
        return name or self._extract_db_name_from_db_path(
        ) + '_db_file_reader'

    def _init_reader_schema(self):
        """Restore a reader schema from the DB file.

        Here it is assumed that:
        1). Each field of the schema have corresponding blobs
            stored in the DB file.
        2). Each blob loaded from the DB file corresponds to
            a field of the schema.

        Load a set of blobs from a DB file. From names of these blobs,
        restore the DB file schema using `from_column_list(...)`.

        Returns:
            schema: schema.Struct. Used in Reader.__init__(...).
        """
        assert os.path.exists(self.db_path), \
            'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
        with core.NameScope(self.name):
            # blob_prefix is for avoiding name conflict in workspace
            blob_prefix = scope.CurrentNameScope()
        workspace.RunOperatorOnce(
            core.CreateOperator(
                'Load',
                [],
                [],
                absolute_path=True,
                db=self.db_path,
                db_type=self.db_type,
                load_all=True,
                add_prefix=blob_prefix,
            )
        )
        col_names = [
            blob_name[len(blob_prefix):] for blob_name in workspace.Blobs()
            if blob_name.startswith(blob_prefix)
        ]
        schema = from_column_list(col_names)
        return schema

    def setup_ex(self, init_net, finish_net):
        """From the Dataset, create a _DatasetReader and setup a init_net.

        Make sure the _init_field_blobs_as_empty(...) is only called once.

        Because the underlying NewRecord(...) creats blobs by calling
        NextScopedBlob(...), so that references to previously-initiated
        empty blobs will be lost, causing accessibility issue.
        """
        if self.ds_reader:
            self.ds_reader.setup_ex(init_net, finish_net)
        else:
            self._init_field_blobs_as_empty(init_net)
            self._feed_field_blobs_from_db_file(init_net)
            self.ds_reader = self.ds.reader(
                init_net,
                batch_size=self.batch_size,
            )

    def read(self, read_net):
        assert self.ds_reader, 'setup_ex must be called first'
        return self.ds_reader.read(read_net)

    def _init_field_blobs_as_empty(self, init_net):
        """Initialize dataset field blobs by creating an empty record"""
        with core.NameScope(self.name):
            self.ds.init_empty(init_net)

    def _feed_field_blobs_from_db_file(self, net):
        """Load from the DB file at db_path and feed dataset field blobs"""
        assert os.path.exists(self.db_path), \
            'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
        net.Load(
            [],
            self.ds.get_blobs(),
            db=self.db_path,
            db_type=self.db_type,
            absolute_path=True,
            source_blob_names=self.ds.field_names(),
        )

    def _extract_db_name_from_db_path(self):
        """Extract DB name from DB path

            E.g. given self.db_path=`/tmp/sample.db`,
            it returns `sample`.

            Returns:
                db_name: str.
        """
        return os.path.basename(self.db_path).rsplit('.', 1)[0]