示例#1
0
    def train(self,
              img_folder,
              ann_folder,
              nb_epoch,
              saved_weights_name,
              batch_size=8,
              jitter=True,
              learning_rate=1e-4,
              train_times=1,
              valid_times=1,
              valid_img_folder="",
              valid_ann_folder="",
              first_trainable_layer=None,
              is_only_detect=False):
        """Load annotations, build generators, model and loss, then train.

        A progress message is printed after each stage completes.

        Args:
            img_folder: Passed to ``get_train_annotations``.
            ann_folder: Passed to ``get_train_annotations``.
            nb_epoch: Passed to the training loop as ``nb_epoch``.
            saved_weights_name: Passed to the training loop as
                ``saved_weights_name``.
            batch_size: Passed to both batch generators and to the loss
                factory.
            jitter: Passed to the training batch generator only; the
                validation generator always receives ``jitter=False``.
            learning_rate: Passed to the training loop.
            train_times: Third argument of the training batch generator.
            valid_times: Third argument of the validation batch generator.
            valid_img_folder: Passed to ``get_train_annotations``.
            valid_ann_folder: Passed to ``get_train_annotations``.
            first_trainable_layer: Passed to ``self._yolo_network.get_model``.
            is_only_detect: Passed to ``get_train_annotations``.
        """
        # Stage 1: load the training / validation annotations.
        annotation_args = (self._labels, img_folder, ann_folder,
                           valid_img_folder, valid_ann_folder, is_only_detect)
        train_set, valid_set = get_train_annotations(*annotation_args)
        print("No error getting annotations")

        # Stage 2: wrap both annotation sets in batch generators.
        train_gen = self._get_batch_generator(
            train_set, batch_size, train_times, jitter=jitter)
        valid_gen = self._get_batch_generator(
            valid_set, batch_size, valid_times, jitter=False)
        print("No error getting batch generator")

        # Stage 3: obtain the model instance and the loss function.
        network = self._yolo_network.get_model(first_trainable_layer)
        loss_fn = self._get_loss_func(batch_size)
        print("No error getting keras model instance & loss fucntion")

        # Stage 4: run the training loop. The bare name ``train`` resolves
        # to the module-level function, not to this method.
        fit_options = {
            "learning_rate": learning_rate,
            "nb_epoch": nb_epoch,
            "saved_weights_name": saved_weights_name,
        }
        train(network, loss_fn, train_gen, valid_gen, **fit_options)
        print("No error running training loop")
示例#2
0
    def train(self,
              img_folder,
              ann_folder,
              nb_epoch,
              saved_weights_name,
              batch_size=8,
              jitter=True,
              learning_rate=1e-4,
              train_times=1,
              valid_times=1,
              warmup_epochs=None,
              valid_img_folder="",
              valid_ann_folder="",
              first_trainable_layer=None,
              is_only_detect=False):
        """Load annotations, build generators, model and loss, then train.

        Args:
            img_folder: Passed to ``get_train_annotations``.
            ann_folder: Passed to ``get_train_annotations``.
            nb_epoch: Passed to the training loop as ``nb_epoch``.
            saved_weights_name: Passed to the training loop as
                ``saved_weights_name``.
            batch_size: Passed to both batch generators and to the loss
                factory.
            jitter: Passed to the training batch generator only; the
                validation generator always receives ``jitter=False``.
            learning_rate: Passed to the training loop.
            train_times: Passed to the training batch generator and to the
                loss factory.
            valid_times: Passed to the validation batch generator and to the
                loss factory.
            warmup_epochs: Passed to the loss factory.
            valid_img_folder: Passed to ``get_train_annotations``.
            valid_ann_folder: Passed to ``get_train_annotations``.
            first_trainable_layer: Passed to ``self._yolo_network.get_model``.
            is_only_detect: Passed to ``get_train_annotations``.
        """
        # Stage 1: load the training / validation annotations.
        annotation_args = (self._labels, img_folder, ann_folder,
                           valid_img_folder, valid_ann_folder, is_only_detect)
        train_set, valid_set = get_train_annotations(*annotation_args)

        # Stage 2: wrap both annotation sets in batch generators.
        train_gen = self._get_batch_generator(
            train_set, batch_size, train_times, jitter=jitter)
        valid_gen = self._get_batch_generator(
            valid_set, batch_size, valid_times, jitter=False)

        # Stage 3: obtain the model instance and the loss function.
        network = self._yolo_network.get_model(first_trainable_layer)
        loss_fn = self._get_loss_func(
            batch_size, warmup_epochs, train_times, valid_times)

        # Stage 4: run the training loop. The bare name ``train`` resolves
        # to the module-level function, not to this method.
        fit_options = {
            "learning_rate": learning_rate,
            "nb_epoch": nb_epoch,
            "saved_weights_name": saved_weights_name,
        }
        train(network, loss_fn, train_gen, valid_gen, **fit_options)
示例#3
0
    def train(
            self,
            img_folder,
            ann_folder,
            img_in_mem,  # datasets in mem, format: list
            ann_in_mem,  # datasets's annotation in mem, format: list
            nb_epoch,
            save_best_weights_path,
            save_final_weights_path,
            batch_size=8,
            jitter=True,
            learning_rate=1e-4,
            train_times=1,
            valid_times=1,
            valid_img_folder="",
            valid_ann_folder="",
            valid_img_in_mem=None,
            valid_ann_in_mem=None,
            first_trainable_layer=None,
            is_only_detect=False,
            progress_callbacks=None):
        """Load annotations, build generators, model and loss, then train.

        Args:
            img_folder: Passed to ``get_train_annotations``.
            ann_folder: Passed to ``get_train_annotations``.
            img_in_mem: In-memory dataset (list), passed to
                ``get_train_annotations``.
            ann_in_mem: In-memory annotations (list), passed to
                ``get_train_annotations``.
            nb_epoch: Passed to the training loop as ``nb_epoch``.
            save_best_weights_path: Passed to the training loop.
            save_final_weights_path: Passed to the training loop.
            batch_size: Passed to both batch generators and to the loss
                factory.
            jitter: Passed to the training batch generator only; the
                validation generator always receives ``jitter=False``.
            learning_rate: Passed to the training loop.
            train_times: Third argument of the training batch generator.
            valid_times: Third argument of the validation batch generator.
            valid_img_folder: Passed to ``get_train_annotations``.
            valid_ann_folder: Passed to ``get_train_annotations``.
            valid_img_in_mem: Passed to ``get_train_annotations``.
            valid_ann_in_mem: Passed to ``get_train_annotations``.
            first_trainable_layer: Passed to ``self._yolo_network.get_model``.
            is_only_detect: Passed to ``get_train_annotations``.
            progress_callbacks: Callbacks passed to the training loop.
                ``None`` (the default) is replaced by a new empty list.

        Returns:
            Whatever the module-level ``train`` function returns.
        """
        # A mutable default ([]) would be one list shared by every call and
        # every instance; anything appended to it downstream would leak into
        # later training runs. Create a fresh list per call instead.
        if progress_callbacks is None:
            progress_callbacks = []

        # 1. get annotations
        train_annotations, valid_annotations = get_train_annotations(
            self._labels, img_folder, ann_folder, valid_img_folder,
            valid_ann_folder, img_in_mem, ann_in_mem, valid_img_in_mem,
            valid_ann_in_mem, is_only_detect)

        # 2. get batch generators (jitter is never applied to validation)
        train_batch_generator = self._get_batch_generator(train_annotations,
                                                          batch_size,
                                                          train_times,
                                                          jitter=jitter)
        valid_batch_generator = self._get_batch_generator(valid_annotations,
                                                          batch_size,
                                                          valid_times,
                                                          jitter=False)

        # 3. get the keras model instance & loss function
        model = self._yolo_network.get_model(first_trainable_layer)
        loss = self._get_loss_func(batch_size)

        # 4. Run training loop. The bare name ``train`` resolves to the
        # module-level function, not to this method.
        history = train(model,
                        loss,
                        train_batch_generator,
                        valid_batch_generator,
                        learning_rate=learning_rate,
                        nb_epoch=nb_epoch,
                        save_best_weights_path=save_best_weights_path,
                        save_final_weights_path=save_final_weights_path,
                        progress_callbacks=progress_callbacks)
        return history