Ejemplo n.º 1
0
 def create_generator(self):
     """Create the batch generators used for training and validation.

     Training annotations come from ``self._get_train_anns()`` with images
     from the ``train_image_folder`` entry of the train config; that
     generator uses the train config's ``min_size``/``max_size``/``jitter``
     settings and shuffles its samples. When ``self._get_valid_anns()``
     returns at least one annotation, a validation generator is built over
     ``valid_image_folder`` with the model's fixed ``net_size``, no jitter
     and no shuffling; otherwise no validation generator is created.

     Prints the number of training and validation samples.

     Returns:
         tuple: ``(train_generator, valid_generator)``; ``valid_generator``
         is ``None`` when there are no validation annotations.
     """
     train_anns = self._get_train_anns()
     valid_anns = self._get_valid_anns()
     train_cfg = self._train_config
     model_cfg = self._model_config

     def build(anns, image_folder, min_size, max_size, jitter, shuffle):
         # Batch size, labels and anchors are identical for both generators.
         return BatchGenerator(anns,
                               image_folder,
                               batch_size=train_cfg["batch_size"],
                               labels=model_cfg["labels"],
                               anchors=model_cfg["anchors"],
                               min_net_size=min_size,
                               max_net_size=max_size,
                               jitter=jitter,
                               shuffle=shuffle)

     train_generator = build(train_anns,
                             train_cfg["train_image_folder"],
                             train_cfg["min_size"],
                             train_cfg["max_size"],
                             train_cfg["jitter"],
                             True)

     valid_generator = None
     if len(valid_anns) > 0:
         # Validation runs at a single, fixed input size.
         fixed_size = model_cfg["net_size"]
         valid_generator = build(valid_anns,
                                 train_cfg["valid_image_folder"],
                                 fixed_size,
                                 fixed_size,
                                 False,
                                 False)

     print("Training samples : {}, Validation samples : {}".format(len(train_anns), len(valid_anns)))
     return train_generator, valid_generator
Ejemplo n.º 2
0
def setup_generator(setup_train_dirs):
    """Return a small, non-jittered BatchGenerator over the fixture data.

    ``setup_train_dirs`` is unpacked as a pair of (annotation file names,
    image root directory). The generator uses a batch size of 2, the single
    label "raccoon", a fixed 288 network input size and no jitter.
    """
    annotation_files, images_dir = setup_train_dirs
    fixed_size = 288
    return BatchGenerator(
        annotation_files,
        images_dir,
        batch_size=2,
        labels=["raccoon"],
        min_net_size=fixed_size,
        max_net_size=fixed_size,
        jitter=False,
    )
Ejemplo n.º 3
0
    def create_generator(self, split_train_valid=False, train_ratio=0.8, split_seed=55):
        """Build the training and validation batch generators.

        Args:
            split_train_valid: if True, the annotations from
                ``self._get_valid_anns()`` are ignored and the validation set
                is instead carved out of the training annotations; images
                for both sets are then read from ``train_image_folder``.
            train_ratio: fraction of the (shuffled) training annotations kept
                for training when ``split_train_valid`` is True; the rest
                become the validation set. Defaults to the original 0.8.
            split_seed: seed of the shuffle performed before the split, so
                the split is reproducible. Defaults to the original 55.

        Returns:
            tuple: ``(train_generator, valid_generator)``. The training
            generator uses the train config's ``min_size``/``max_size``/
            ``jitter`` settings and shuffles; the validation generator uses
            the model's fixed ``net_size`` with no jitter and no shuffling.
        """
        train_ann_fnames = self._get_train_anns()
        valid_ann_fnames = self._get_valid_anns()
        img_folder = 'valid_image_folder'

        if split_train_valid:
            # Shuffle a copy so the sequence returned by _get_train_anns() is
            # not mutated in place (a cached list would otherwise be
            # re-shuffled on every call, giving a different split each time).
            shuffled = list(train_ann_fnames)
            # A private RandomState produces the same permutation as seeding
            # the global RNG with the same value, but leaves the global numpy
            # RNG untouched. The previous seed(55)/seed() pair re-seeded the
            # global RNG from OS entropy, discarding any caller-set seed.
            np.random.RandomState(split_seed).shuffle(shuffled)
            train_valid_split = int(train_ratio * len(shuffled))
            train_ann_fnames = shuffled[:train_valid_split]
            valid_ann_fnames = shuffled[train_valid_split:]
            # The carved-out validation images live in the training folder.
            img_folder = 'train_image_folder'

        # NOTE(review): when not splitting, valid_ann_fnames may be empty and
        # a generator over zero samples is still created here.
        valid_generator = BatchGenerator(
            valid_ann_fnames,
            self._train_config[img_folder],
            batch_size=self._train_config['batch_size'],
            labels=self._model_config['labels'],
            anchors=self._model_config['anchors'],
            min_net_size=self._model_config['net_size'],
            max_net_size=self._model_config['net_size'],
            jitter=False,
            shuffle=False)

        train_generator = BatchGenerator(
            train_ann_fnames,
            self._train_config['train_image_folder'],
            batch_size=self._train_config['batch_size'],
            labels=self._model_config['labels'],
            anchors=self._model_config['anchors'],
            min_net_size=self._train_config['min_size'],
            max_net_size=self._train_config['max_size'],
            jitter=self._train_config['jitter'],
            shuffle=True)

        print('Training samples : {}, Validation samples : {}'.format(
            len(train_ann_fnames), len(valid_ann_fnames)))
        return train_generator, valid_generator
Ejemplo n.º 4
0
def test_train(setup_tf_eager, setup_darknet_weights, setup_train_dirs):
    """Train for three epochs on the fixture data and check the loss drops.

    ``setup_tf_eager`` is requested only for its fixture side effect
    (presumably enabling TF eager execution — confirm in conftest).
    """
    annotation_files, images_dir = setup_train_dirs
    weights_path = setup_darknet_weights

    # 1. create generators; validation deliberately reuses the training data
    def make_generator():
        return BatchGenerator(annotation_files,
                              images_dir,
                              batch_size=2,
                              labels_naming=["raccoon"],
                              jitter=False)

    train_gen = make_generator()
    valid_gen = make_generator()

    # 2. create a single-class model and load the Darknet weights into it
    yolo = Yolonet(n_classes=1)
    yolo.load_darknet_params(weights_path, True)

    # 3. train, then require the last recorded loss to be below the first
    losses = train_fn(yolo, train_gen, valid_gen, num_epoches=3)
    assert losses[0] > losses[-1]
Ejemplo n.º 5
0
    '--config',
    default="configs/raccoon.json",
    help='config file')


if __name__ == '__main__':
    args = argparser.parse_args()
    with open(args.config) as data_file:    
        config = json.load(data_file)
    
    # 1. create generator
    ann_fnames = glob.glob(os.path.join(config["train"]["train_annot_folder"], "*.xml"))
    print(len(ann_fnames))
    train_generator = BatchGenerator(ann_fnames,
                                     config["train"]["train_image_folder"],
                                     batch_size=config["train"]["batch_size"],
                                     labels_naming=config["model"]["labels"],
                                     anchors=config["model"]["anchors"],
                                     jitter=config["train"]["jitter"])

    valid_generator = BatchGenerator(ann_fnames,
                                       config["train"]["train_image_folder"],
                                       batch_size=config["train"]["batch_size"],
                                       labels_naming=config["model"]["labels"],
                                       anchors=config["model"]["anchors"],
                                       jitter=False,
                                       shuffle=False)
    
    print(train_generator.steps_per_epoch)
    
    # 2. create model
    model = Yolonet(n_classes=len(config["model"]["labels"]))