示例#1
0
    def __init__(
        self,
        input_shape,
        model_configuration,
        classes_file,
        anchors=None,
        masks=None,
        max_boxes=100,
        iou_threshold=0.5,
        score_threshold=0.5,
    ):
        """
        Initialize detection settings.

        Reads the class names from ``classes_file``, assigns every class a
        random box color, forwards the model settings to the parent
        initializer and finally calls ``activate_gpu()``.

        Args:
            input_shape: tuple, (n, n, c)
            model_configuration: Path to DarkNet cfg file.
            classes_file: File containing class names \\n delimited.
            anchors: numpy array of (w, h) pairs.
            masks: numpy array of masks.
            max_boxes: Maximum boxes of the TFRecords provided(if any) or
                maximum boxes setting.
            iou_threshold: float, values less than the threshold are ignored.
            score_threshold: float, values less than the threshold are ignored.

        Attributes set here:
            class_names: List of stripped lines of ``classes_file``.
            box_colors: Dict mapping each class name to a list of 3 random
                floats in [0, 256).
        """
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(classes_file) as classes:
            self.class_names = [item.strip() for item in classes]
        # One random color per class name, drawn in class order.
        self.box_colors = {
            class_name: list(np.random.random(size=3) * 256)
            for class_name in self.class_names
        }
        super().__init__(
            input_shape=input_shape,
            model_configuration=model_configuration,
            classes=len(self.class_names),
            anchors=anchors,
            masks=masks,
            max_boxes=max_boxes,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
        )
        activate_gpu()
    def train(
        self,
        epochs,
        batch_size,
        learning_rate,
        new_anchors_conf=None,
        new_dataset_conf=None,
        dataset_name=None,
        weights=None,
        evaluate=True,
        merge_evaluation=True,
        evaluation_workers=8,
        shuffle_buffer=512,
        min_overlaps=None,
        display_stats=True,
        plot_stats=True,
        save_figs=True,
        clear_outputs=False,
        n_epoch_eval=None,
    ):
        """
        Train on the dataset.

        Optionally regenerates anchors and the dataset, builds the models,
        compiles the training model with an Adam optimizer and one loss per
        mask, fits it, and optionally evaluates the resulting checkpoint.

        Args:
            epochs: Number of training epochs.
            batch_size: Training batch size.
            learning_rate: non-negative value.
            new_anchors_conf: A dictionary containing anchor generation configuration.
            new_dataset_conf: A dictionary containing dataset generation configuration.
            dataset_name: Name of the dataset for model checkpoints; when
                falsy the checkpoint is named ``trained_model.tf``.
            weights: .tf or .weights file
            evaluate: If False, the trained model will not be evaluated after training.
            merge_evaluation: If False, training and validation maps will
                be calculated separately.
            evaluation_workers: Parallel predictions.
            shuffle_buffer: Buffer size for shuffling datasets.
            min_overlaps: a float value between 0 and 1, or a dictionary
                containing each class in self.class_names mapped to its
                minimum overlap. Any falsy value (None, 0, empty dict)
                is replaced by 0.5.
            display_stats: If True and evaluate=True, evaluation statistics will be displayed.
            plot_stats: If True, Precision and recall curves as well as
                comparative bar charts will be plotted
            save_figs: If True and plot_stats=True, figures will be saved
            clear_outputs: If True, old outputs will be cleared
            n_epoch_eval: Conduct evaluation every n epoch.

        Returns:
            If ``evaluate`` is truthy, a tuple ``(evaluations, history)``
            where ``evaluations`` is whatever ``self.evaluate`` returns and
            ``history`` is the result of ``self.training_model.fit``.
            Otherwise only ``history``.
        """
        # Falsy min_overlaps (including an explicit 0) falls back to 0.5.
        min_overlaps = min_overlaps or 0.5
        if clear_outputs:
            self.clear_outputs()
        activate_gpu()
        default_logger.info('Starting training ...')
        if new_anchors_conf:
            default_logger.info('Generating new anchors ...')
            self.generate_new_anchors(new_anchors_conf)
        # Models are created after the optional anchor regeneration and
        # before any weights are loaded into them.
        self.create_models()
        if weights:
            self.load_weights(weights)
        if new_dataset_conf:
            self.create_new_dataset(new_dataset_conf)
        self.check_tf_records()
        training_dataset = self.initialize_dataset(
            self.train_tf_record, batch_size, shuffle_buffer
        )
        valid_dataset = self.initialize_dataset(
            self.valid_tf_record, batch_size, shuffle_buffer
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate)
        # One loss per mask, each built from the anchors selected by it.
        loss = [
            calculate_loss(
                self.anchors[mask], self.classes, self.iou_threshold
            )
            for mask in self.masks
        ]
        self.training_model.compile(optimizer=optimizer, loss=loss)
        checkpoint_name = os.path.join(
            '..', 'Models', f'{dataset_name or "trained"}_model.tf'
        )
        callbacks = self.create_callbacks(checkpoint_name)
        if n_epoch_eval:
            # Periodic evaluation during training is handled by an extra
            # callback; its arguments are positional, so order matters.
            mid_train_eval = MidTrainingEvaluator(
                self.input_shape,
                self.classes_file,
                self.image_width,
                self.image_height,
                self.train_tf_record,
                self.valid_tf_record,
                self.anchors,
                self.masks,
                self.max_boxes,
                self.iou_threshold,
                self.score_threshold,
                n_epoch_eval,
                merge_evaluation,
                evaluation_workers,
                shuffle_buffer,
                min_overlaps,
                display_stats,
                plot_stats,
                save_figs,
                checkpoint_name,
            )
            callbacks.append(mid_train_eval)
        history = self.training_model.fit(
            training_dataset,
            epochs=epochs,
            callbacks=callbacks,
            validation_data=valid_dataset,
        )
        default_logger.info('Training complete')
        if evaluate:
            evaluations = self.evaluate(
                checkpoint_name,
                merge_evaluation,
                evaluation_workers,
                shuffle_buffer,
                min_overlaps,
                display_stats,
                plot_stats,
                save_figs,
            )
            return evaluations, history
        return history