Exemplo n.º 1
0
def train():
    config = DeepFashionTrainConfig()
    config.display()
    model = MaskRCNN(mode="training", config=config, model_dir="checkpoints")
    # Load coco pretrained weights
    coco_weights = "models/mask_rcnn_coco.h5"
    # As we don't want to detect default coco classes, we want to exclude the last layers.
    # If not model requires matching number of coco classes
    model.load_weights(coco_weights,
                       by_name=True,
                       exclude=[
                           "mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox",
                           "mrcnn_mask"
                       ])

    dataset_train = DeepFashionDataset()
    dataset_train.load_coco(config.TRAIN_IMG_DIR,
                            config.TRAIN_ANNOTATIONS_PATH)
    dataset_train.prepare()

    dataset_valid = DeepFashionDataset()
    dataset_valid.load_coco(config.VALID_IMG_DIR,
                            config.VALID_ANNOTATIONS_PATH)
    dataset_valid.prepare()

    model.train(dataset_train,
                dataset_valid,
                learning_rate=config.LEARNING_RATE,
                epochs=30,
                layers='heads')

    print("Finish")
Exemplo n.º 2
0
# Подготавливаем тренировочный и тестовый набор данных
# train_set = DataSetFactory.new_instance('data/img')
# test_set = DataSetFactory.new_instance('data/val')
train_set = CocoSetFactory.new_instance(COCO_PATH, 'train', 2017, CLASS_IDS)
test_set = CocoSetFactory.new_instance(COCO_PATH, 'val', 2017, CLASS_IDS)

# Определяется конфигурация
config = DetectConfig()
config.display()
# Определяется модель
model = MaskRCNN(mode='training', model_dir='./', config=config)
# Произвести загрузку стартовой модели
model.load_weights(START_MODEL_PATH,
                   by_name=True,
                   exclude=[
                       "mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox",
                       "mrcnn_mask"
                   ])

# augmentation = imgaug.augmenters.Fliplr(0.5)

# Запуск тренировки тестовой модели
model.train(train_set,
            test_set,
            learning_rate=config.LEARNING_RATE,
            epochs=EPOCHS,
            layers='heads'
            # augmentation=augmentation
            )