コード例 #1
0
    def setup_method(self):
        """
        Prepare fixtures shared by every test method in this class.

        Initializes the Logger, builds default ``DmdsParams``, registers
        ``./config.ini`` and then, for each hard-coded driving_stereo scene,
        loads an 80/20 train/val id split from the ``depth`` collection.
        The three resulting lists (``train_data``, ``val_data`` and
        ``collection_details``) are parallel: index ``i`` belongs to scene ``i``.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = DmdsParams()

        Config.add_config('./config.ini')
        depth_collection = ("local_mongodb", "depth", "driving_stereo")
        scene_tokens = (
            "2018-10-26-15-24-18",
            "2018-10-19-09-30-39",
        )

        self.train_data, self.val_data, self.collection_details = [], [], []
        for token in scene_tokens:
            # One query per scene: filtered on scene_token, ordered by
            # timestamp and not shuffled; limit=100 is passed through to
            # load_ids (presumably a cap on fetched entries — confirm).
            train_ids, val_ids = load_ids(
                depth_collection,
                data_split=(80, 20),
                limit=100,
                shuffle_data=False,
                mongodb_filter={"scene_token": token},
                sort_by={"timestamp": 1},
            )
            self.train_data.append(train_ids)
            self.val_data.append(val_ids)
            self.collection_details.append(depth_collection)
コード例 #2
0
    def setup_method(self):
        """
        Prepare fixtures shared by every test method in this class.

        Builds ``CenternetParams`` sized by ``len(OD_CLASS_MAPPING)`` and
        activates the ``l_shape`` and ``3d_info`` regression fields. It then
        registers ``./config.ini`` and loads a 70/30 train/val id split from
        the ``nuscenes_train`` labels collection (limit=250).
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = CenternetParams(len(OD_CLASS_MAPPING))
        for field_name in ("l_shape", "3d_info"):
            self.params.REGRESSION_FIELDS[field_name].active = True

        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "nuscenes_train")

        id_split = load_ids(
            self.collection_details,
            data_split=(70, 30),
            limit=250,
        )
        self.train_data, self.val_data = id_split
コード例 #3
0
    def setup_method(self):
        """
        Prepare fixtures shared by every test method in this class.

        Builds default ``SemsegParams``, registers ``./config.ini`` and loads
        a 70/30 train/val id split from the ``comma10k`` labels collection.
        It passes ``limit=30``, so this is not a single entry.
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = SemsegParams()

        Config.add_config('./config.ini')
        comma10k = ("local_mongodb", "labels", "comma10k")
        self.collection_details = comma10k

        train_ids, val_ids = load_ids(comma10k, data_split=(70, 30), limit=30)
        self.train_data = train_ids
        self.val_data = val_ids
コード例 #4
0
    def setup_method(self):
        """
        Prepare fixtures shared by every test method in this class.

        Builds ``MultitaskParams`` sized by the number of object-detection
        classes, registers ``./config.ini`` and loads a shuffled 70/30
        train/val id split (limit=30) from the ``nuscenes_train`` labels
        collection.

        The split is stored as ``self.td`` / ``self.vd``, not as
        ``train_data``/``val_data`` like sibling fixtures. Those names are
        kept because the test methods presumably read them.
        """
        Logger.init()
        Logger.remove_file_logger()

        # Take len() of the mapping directly. Building an ``.items()`` view
        # only to measure it is redundant, and this matches how the other
        # fixtures size their params.
        self.params = MultitaskParams(len(OD_CLASS_MAPPING))

        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "nuscenes_train")

        self.td, self.vd = load_ids(
            self.collection_details,
            data_split=(70, 30),
            shuffle_data=True,
            limit=30,
        )
コード例 #5
0
    def setup_method(self):
        """
        Prepare fixtures shared by every test method in this class.

        Builds ``CentertrackerParams`` sized by ``len(OD_CLASS_MAPPING)``,
        registers ``./config.ini`` and loads a 70/30 train/val id split
        (limit=100) from the ``kitti`` labels collection.
        """
        Logger.init()
        Logger.remove_file_logger()

        class_count = len(OD_CLASS_MAPPING)
        self.params = CentertrackerParams(class_count)

        Config.add_config('./config.ini')
        self.collection_details = ("local_mongodb", "labels", "kitti")

        split = load_ids(self.collection_details, data_split=(70, 30), limit=100)
        self.train_data, self.val_data = split
コード例 #6
0
    def setup_method(self):
        """
        Prepare fixtures shared by every test method in this class.

        Builds default ``Params``, registers ``./config.ini`` and loads a
        shuffled 70/30 train/val id split (limit=100) from the
        ``nuscenes_train`` labels collection.

        The split and the collection tuple are each wrapped in a
        one-element list, presumably because consumers expect parallel
        per-collection lists (confirm against the test methods).
        """
        Logger.init()
        Logger.remove_file_logger()

        self.params = Params()

        Config.add_config('./config.ini')
        nuscenes = ("local_mongodb", "labels", "nuscenes_train")

        train_ids, val_ids = load_ids(
            nuscenes,
            data_split=(70, 30),
            limit=100,
            shuffle_data=True,
        )
        self.train_data, self.val_data, self.collection_details = (
            [train_ids],
            [val_ids],
            [nuscenes],
        )
コード例 #7
0
        data = data_adapter.expand_1d(original_data)
        x, y_true, w = data_adapter.unpack_x_y_sample_weight(data)
        y_pred = keras_model(x, training=True)
        result = original_train_step(original_data)
        # custom stuff called during training
        show_pygame.show(x, y_true, y_pred)
        return result
    return call_custom_callbacks

if __name__ == "__main__":
    Logger.init()
    Logger.remove_file_logger()

    params = MultitaskParams(len(OD_CLASS_MAPPING.items()))

    Config.add_config('./config.ini')
    con = ("local_mongodb", "labels", "nuscenes_train")

    td, vd = load_ids(
        con,
        data_split=(90, 10),
        shuffle_data=True
    )

    train_data = [td]
    val_data = [vd]
    collection_details = [con]

    train_gen = MongoDBGenerator(
        collection_details,
        train_data,