コード例 #1
0
ファイル: deepforest.py プロジェクト: waynechao128/DeepForest
    def train(self, annotations, comet_experiment=None):
        """Train a deep learning tree detection model using keras-retinanet.

        This is the main entry point for training a new model based on either
        existing weights or scratch.

        Args:
            annotations (str): Path to csv label file, labels are in the
                format -> path/to/image.jpg,x1,y1,x2,y2,class_name
            comet_experiment: A comet ml object to log images. Optional.

        Returns:
            model (object): The trained keras model returned by
            ``retinanet_train``. It is also stored as ``self.training_model``.
            The result of ``convert_model(self.training_model)`` is stored as
            ``self.prediction_model``.
        """
        arg_list = utilities.format_args(annotations, self.config)

        print("Training retinanet with the following args {}".format(arg_list))

        # Train model
        self.training_model = retinanet_train(arg_list, comet_experiment)

        # Create prediction model from the freshly trained model
        self.prediction_model = convert_model(self.training_model)

        # Honor the documented contract. The attributes above remain set for
        # callers that read them instead of using the return value.
        return self.training_model
コード例 #2
0
    def train(self,
              annotations,
              input_type="fit_generator",
              list_of_tfrecords=None,
              comet_experiment=None,
              images_per_epoch=None):
        """Train a deep learning tree detection model using keras-retinanet.

        This is the main entry point for training a new model based on either
        existing weights or scratch.

        Args:
            annotations (str): Path to csv label file,
                labels are in the format -> path/to/image.png,x1,y1,x2,y2,class_name
            input_type: "fit_generator" or "tfrecord"
            list_of_tfrecords: Ignored if input_type != "tfrecord",
                list of tf records to process
            comet_experiment: A comet ml object to log images. Optional.
            images_per_epoch: number of images to override default config
                of images in annotations file / batch size. Useful for debug

        Returns:
            model (object): The trained keras model (``self.model``).

        Side effects:
            Sets ``self.classes_file`` from ``utilities.create_classes`` and
            calls ``self.read_classes()``. Stores the three models returned
            by ``retinanet_train`` on ``self.model``,
            ``self.prediction_model`` (with bbox nms) and
            ``self.training_model`` (without nms).
        """
        # Regenerate the classes file from the annotations in case the number
        # of classes has changed since the model was created.
        self.classes_file = utilities.create_classes(annotations)
        self.read_classes()
        arg_list = utilities.format_args(annotations, self.classes_file, self.config,
                                         images_per_epoch)

        print("Training retinanet with the following args {}".format(arg_list))

        # Train model
        self.model, self.prediction_model, self.training_model = retinanet_train(
            forest_object=self,
            args=arg_list,
            input_type=input_type,
            list_of_tfrecords=list_of_tfrecords,
            comet_experiment=comet_experiment)

        # Honor the documented contract. The attributes above remain set for
        # callers that read them instead of using the return value.
        return self.model