示例#1
0
    def setup_method(self):
        """
        Build per-scene fixtures from the driving_stereo depth collection.

        Runs before each test method. It initializes the Logger and removes
        its file logger, then creates default DmdsParams and registers
        ./config.ini. Finally it calls load_ids once per scene token.

        Each load_ids call:
        - filters on that scene's token,
        - disables shuffling,
        - sorts by ascending timestamp,
        - uses an 80/20 split and limit=100.
        Presumably this keeps the frames of a scene in temporal order.

        Afterwards self.train_data, self.val_data and
        self.collection_details are parallel lists with one entry per scene.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = DmdsParams()

        # Register the local config file (presumably holds the MongoDB
        # connection settings used by load_ids).
        Config.add_config('./config.ini')
        source = ("local_mongodb", "depth", "driving_stereo")
        scene_tokens = (
            "2018-10-26-15-24-18",
            "2018-10-19-09-30-39",
        )
        self.train_data, self.val_data, self.collection_details = [], [], []

        # One query per scene; the same collection tuple is recorded for each
        # so the three lists stay index-aligned.
        for token in scene_tokens:
            train_ids, val_ids = load_ids(
                source,
                data_split=(80, 20),
                limit=100,
                shuffle_data=False,
                mongodb_filter={"scene_token": token},
                sort_by={"timestamp": 1},
            )
            self.train_data.append(train_ids)
            self.val_data.append(val_ids)
            self.collection_details.append(source)
示例#2
0
    def setup_method(self):
        """
        Build CenterNet fixtures from the nuscenes_train labels collection.

        Runs before each test method. It initializes the Logger and removes
        its file logger, then creates CenternetParams sized by
        OD_CLASS_MAPPING with the "l_shape" and "3d_info" regression fields
        switched on. It then registers ./config.ini and loads a 70/30
        train/val id split with limit=250.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = CenternetParams(len(OD_CLASS_MAPPING))
        # Enable the optional regression heads exercised by these tests.
        for field_name in ("l_shape", "3d_info"):
            self.params.REGRESSION_FIELDS[field_name].active = True

        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "nuscenes_train")

        train_ids, val_ids = load_ids(
            self.collection_details,
            data_split=(70, 30),
            limit=250,
        )
        self.train_data = train_ids
        self.val_data = val_ids
示例#3
0
    def setup_method(self):
        """
        Build semantic-segmentation fixtures from the comma10k labels collection.

        Runs before each test method. It initializes the Logger and removes
        its file logger, then creates default SemsegParams and registers
        ./config.ini. Finally it loads a 70/30 train/val id split with
        limit=30.
        """
        Logger.init()
        Logger.remove_file_logger()
        self.params = SemsegParams()

        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "comma10k")

        train_ids, val_ids = load_ids(
            self.collection_details, data_split=(70, 30), limit=30
        )
        self.train_data = train_ids
        self.val_data = val_ids
示例#4
0
    def setup_method(self):
        """
        Build multitask fixtures from the nuscenes_train labels collection.

        Runs before each test method. It initializes the Logger and removes
        its file logger, then creates MultitaskParams sized by the number of
        entries in OD_CLASS_MAPPING. It then registers ./config.ini and loads
        a shuffled 70/30 train/val id split with limit=30.

        The id lists are stored as ``self.td`` / ``self.vd`` rather than
        ``train_data`` / ``val_data``. These attribute names are kept as-is
        because the test methods may rely on them.
        """
        Logger.init()
        Logger.remove_file_logger()

        # len() on a mapping counts its entries directly; this is the same
        # value as len(mapping.items()) without creating an items view.
        self.params = MultitaskParams(len(OD_CLASS_MAPPING))

        # Register the local config file (presumably holds the MongoDB
        # connection settings used by load_ids).
        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "nuscenes_train")

        # Shuffled 70/30 train/validation id split, limit=30.
        self.td, self.vd = load_ids(
            self.collection_details,
            data_split=(70, 30),
            shuffle_data=True,
            limit=30
        )
示例#5
0
    def setup_method(self):
        """
        Build CenterTracker fixtures from the kitti labels collection.

        Runs before each test method. It initializes the Logger and removes
        its file logger, then creates CentertrackerParams sized by
        OD_CLASS_MAPPING and registers ./config.ini. Finally it loads a 70/30
        train/val id split with limit=100.
        """
        Logger.init()
        Logger.remove_file_logger()

        class_count = len(OD_CLASS_MAPPING)
        self.params = CentertrackerParams(class_count)

        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "kitti")

        id_split = load_ids(self.collection_details, data_split=(70, 30), limit=100)
        self.train_data, self.val_data = id_split
示例#6
0
    def setup_method(self):
        """
        Build single-source fixtures from the nuscenes_train labels collection.

        Runs before each test method. It initializes the Logger and removes
        its file logger, then creates default Params and registers
        ./config.ini. Finally it loads a shuffled 70/30 train/val id split
        with limit=100.

        The results are stored as one-element lists in self.train_data,
        self.val_data and self.collection_details.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = Params()

        Config.add_config('./config.ini')
        source = ("local_mongodb", "labels", "nuscenes_train")

        train_ids, val_ids = load_ids(
            source,
            data_split=(70, 30),
            limit=100,
            shuffle_data=True,
        )
        # Wrapped in one-element lists, presumably because consumers expect
        # one id list per collection (parallel-list layout).
        self.train_data, self.val_data, self.collection_details = (
            [train_ids],
            [val_ids],
            [source],
        )
示例#7
0
        show_pygame.show(x, y_true, y_pred)
        return result
    return call_custom_callbacks

if __name__ == "__main__":
    Logger.init()
    Logger.remove_file_logger()

    params = MultitaskParams(len(OD_CLASS_MAPPING.items()))

    Config.add_config('./config.ini')
    con = ("local_mongodb", "labels", "nuscenes_train")

    td, vd = load_ids(
        con,
        data_split=(90, 10),
        shuffle_data=True
    )

    train_data = [td]
    val_data = [vd]
    collection_details = [con]

    train_gen = MongoDBGenerator(
        collection_details,
        train_data,
        batch_size=params.BATCH_SIZE,
        processors=[ProcessImages(params, [0, 0])],
        shuffle_data=True
    )
    val_gen = MongoDBGenerator(
示例#8
0
        # "2018-10-11-17-08-31",
        # "2018-08-13-17-45-03",
        # "2018-08-13-15-32-19",
        # "2018-07-31-11-22-31",
        # "2018-07-31-11-07-48",
    ]
    train_data = []
    val_data = []
    collection_details = []

    # get ids
    for scene_token in scenes:
        td, vd = load_ids(
            con,
            data_split=(95, 5),
            shuffle_data=False,
            mongodb_filter={"scene_token": scene_token},
            sort_by={"timestamp": 1}
        )
        train_data.append(td)
        val_data.append(vd)
        collection_details.append(con)

    processors = [ProcessImages(params)]
    train_gen = MongoDBGenerator(
        collection_details,
        train_data,
        batch_size=params.BATCH_SIZE,
        processors=processors,
        data_group_size=2,
        continues_data_selection=True,
示例#9
0
from models.depth import create_model, Params, ProcessImages
from models.depth.loss import DepthLoss
from common.utils import set_weights

if __name__ == "__main__":
    # Script entry point: set up logging, load an id split from the
    # driving_stereo collection and build train/validation generators.
    # create_model, DepthLoss and set_weights are imported above but not used
    # in this excerpt; presumably the training setup continues after it.
    Logger.init()
    Logger.remove_file_logger()

    params = Params()

    # Register the local config file (presumably holds the MongoDB connection
    # settings used by load_ids / MongoDBGenerator).
    Config.add_config('./config.ini')
    con = ("local_mongodb", "labels", "driving_stereo")
    # Alternative dataset, kept commented out for quick switching:
    # con = ("local_mongodb", "labels", "nuscenes_train")

    # Shuffled 87/13 train/val split; no limit is passed, so the query is not
    # capped by this call.
    td, vd = load_ids(con, data_split=(87, 13), shuffle_data=True)
    # Parallel one-element lists: one id list per collection, in the layout
    # MongoDBGenerator is called with below.
    train_data = [td]
    val_data = [vd]
    collection_details = [con]

    # NOTE(review): the second ProcessImages argument differs between the
    # train ([0, 0]) and validation (False) generators. Its meaning is not
    # visible here, so confirm against ProcessImages before changing either.
    train_gen = MongoDBGenerator(collection_details,
                                 train_data,
                                 batch_size=params.BATCH_SIZE,
                                 processors=[ProcessImages(params, [0, 0])],
                                 shuffle_data=True)
    val_gen = MongoDBGenerator(collection_details,
                               val_data,
                               batch_size=params.BATCH_SIZE,
                               processors=[ProcessImages(params, False)],
                               shuffle_data=True)
示例#10
0
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4864)])

if __name__ == "__main__":
    Logger.init()
    Logger.remove_file_logger()

    params = CentertrackerParams(len(OD_CLASS_MAPPING))
    params.REGRESSION_FIELDS["l_shape"].active = False
    params.REGRESSION_FIELDS["3d_info"].active = False

    Config.add_config('./config.ini')
    collection_details = ("local_mongodb", "labels", "kitti")

    # Create Data Generators
    train_data, val_data = load_ids(collection_details,
                                    data_split=(82, 18),
                                    shuffle_data=True)

    processors = [CenterTrackerProcess(params)]
    train_gen = MongoDBGenerator(collection_details,
                                 train_data,
                                 batch_size=params.BATCH_SIZE,
                                 processors=processors)
    val_gen = MongoDBGenerator(collection_details,
                               val_data,
                               batch_size=params.BATCH_SIZE,
                               processors=processors)

    loss = CenterTrackerProcess(params)
    # metrics = [loss.class_focal_loss, loss.r_offset_loss, loss.fullbox_loss]
    metrics = []