# 示例#1 (Example #1)
# 0
# Fragment of a gender-classifier training script (Example #1).
# dataset_name, validation_split, batch_size, input_shape, grayscale,
# do_random_crop, num_classes and patience are not defined in this fragment;
# presumably they are set earlier in the original script.

# Paths
images_path = '../datasets/imdb_crop/'
log_file_path = '../datasets/trained_models/gender_models/gender_training.log'
trained_models_path = '../datasets/trained_models/gender_models/gender_mini_XCEPTION'

# Load data and split it into training / validation keys
data_loader = DataManager(dataset_name)
ground_truth_data = data_loader.get_data()
train_keys, val_keys = split_imdb_data(ground_truth_data, validation_split)
print('Number of training samples:', len(train_keys))
print('Number of validation samples:', len(val_keys))
image_generator = ImageGenerator(ground_truth_data,
                                 batch_size,
                                 input_shape[:2],
                                 train_keys,
                                 val_keys,
                                 None,
                                 path_prefix=images_path,
                                 vertical_flip_probability=0,
                                 grayscale=grayscale,
                                 do_random_crop=do_random_crop)

# Model definition and compilation
model = mini_XCEPTION(input_shape, num_classes)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

# Training callbacks
early_stop = EarlyStopping(monitor='val_loss', patience=patience)
# NOTE(review): this call was truncated in the source; the remaining
# arguments mirror the identical ReduceLROnPlateau call in
# train_gender_classifier — confirm against the original script.
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                              factor=0.1,
                              patience=int(patience / 2),
                              verbose=1)
import os

from utils.data_augmentation import ImageGenerator
from utils.boxes import create_prior_boxes

# SSD300 training script fragment. DataManager, SSD300, MultiboxLoss and Adam
# are not imported in this fragment; presumably they are imported earlier.
batch_size = 3
num_epochs = 2
image_shape = (300, 300, 3)  # height, width, channels fed to SSD300

# Training data: VOC2007 and VOC2012 'trainval' splits combined.
dataset_manager = DataManager(['VOC2007', 'VOC2012'], ['trainval', 'trainval'])
train_data = dataset_manager.load_data()
# NOTE(review): the VOC2007 'test' split is used as validation data as well,
# so validation metrics are not independent of the final test evaluation.
val_data = test_data = DataManager('VOC2007', 'test').load_data()
class_names = dataset_manager.class_names
num_classes = len(class_names)


prior_boxes = create_prior_boxes()
generator = ImageGenerator(train_data, val_data, prior_boxes, batch_size)

weights_path = '../trained_models/SSD300_weights.hdf5'
# Early VGG-style layers listed here are passed to SSD300 as frozen layers.
frozen_layers = ['input_1', 'conv1_1', 'conv1_2', 'pool1',
                 'conv2_1', 'conv2_2', 'pool2',
                 'conv3_1', 'conv3_2', 'conv3_3', 'pool3']

model = SSD300(image_shape, num_classes, weights_path, frozen_layers, True)
multibox_loss = MultiboxLoss(num_classes, neg_pos_ratio=2.0).compute_loss
model.compile(Adam(lr=3e-4), loss=multibox_loss)


# callbacks
model_path = '../trained_models/SSD_scratch/'

# Make sure the checkpoint directory exists; exist_ok makes this a no-op
# when it is already there (no separate exists() check, so no race).
os.makedirs(model_path, exist_ok=True)
def train_gender_classifier():
    """Train a mini-XCEPTION gender classifier on the IMDB face crops.

    Loads the 'imdb' ground-truth annotations, splits them into training and
    validation keys, builds an ImageGenerator over the cropped images, then
    compiles and fits a mini_XCEPTION model. Callbacks: best-val_loss model
    checkpointing, CSV logging, early stopping and learning-rate reduction.

    Returns:
        None. Training progress is written to ``log_file_path`` and model
        checkpoints to files prefixed with ``trained_models_path``.
    """
    # Hyper-parameters
    batch_size = 32  # samples per training batch
    epochs = 10000  # upper bound; EarlyStopping may end training earlier
    input_shape = (64, 64, 1)  # image height, width, channels
    validation_split = .2  # fraction of the data held out for validation
    num_classes = 2  # two classes: male and female
    # Epochs without val_loss improvement before EarlyStopping ends training;
    # ReduceLROnPlateau below uses half of it so the LR is lowered first.
    patience = 100
    dataset_name = 'imdb'
    do_random_crop = False
    # Grayscale iff the input has a single channel. Always bind the name so a
    # 3-channel input_shape does not raise NameError at the generator below.
    grayscale = input_shape[2] == 1

    # Paths
    images_path = '../datasets/imdb_crop/'
    log_file_path = '../datasets/trained_models/gender_models/gender_training.log'
    trained_models_path = '../datasets/trained_models/gender_models/gender_mini_XCEPTION'

    # Load data and split it into training / validation keys
    data_loader = DataManager(dataset_name)
    ground_truth_data = data_loader.get_data()
    train_keys, val_keys = split_imdb_data(ground_truth_data, validation_split)
    print('Number of training samples:', len(train_keys))
    print('Number of validation samples:', len(val_keys))
    image_generator = ImageGenerator(ground_truth_data,
                                     batch_size,
                                     input_shape[:2],
                                     train_keys,
                                     val_keys,
                                     None,
                                     path_prefix=images_path,
                                     vertical_flip_probability=0,
                                     grayscale=grayscale,
                                     do_random_crop=do_random_crop)

    # Model definition and compilation
    model = mini_XCEPTION(input_shape, num_classes)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()

    # Training callbacks
    early_stop = EarlyStopping(monitor='val_loss', patience=patience)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.1,
                                  patience=patience // 2,
                                  verbose=1)
    csv_logger = CSVLogger(log_file_path, append=False)
    # Checkpoint filenames embed the epoch number and validation accuracy.
    model_names = trained_models_path + '.{epoch:02d}-{val_acc:.2f}.hdf5'
    model_checkpoint = ModelCheckpoint(model_names,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True,
                                       save_weights_only=False)
    callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]

    # Train. Step counts are clamped to at least 1 so a dataset smaller than
    # one batch still yields a runnable epoch instead of zero steps.
    model.fit_generator(image_generator.flow(mode='train'),
                        steps_per_epoch=max(1, len(train_keys) // batch_size),
                        epochs=epochs,
                        verbose=1,
                        callbacks=callbacks,
                        validation_data=image_generator.flow('val'),
                        validation_steps=max(1, len(val_keys) // batch_size))
# 示例#4 (Example #4)
# 0
# SSD300 training script fragment (Example #4): loads annotations, builds and
# compiles the model, then sets up the augmentation generator and callbacks.
# data_manager, image_shape, num_classes, weights_path, frozen_layers,
# optimizer, box_scale_factors, batch_size, image_prefix, scheduler and
# trained_models_filename are not defined in this fragment; presumably they
# are set earlier in the original script.
ground_truth_data = data_manager.get_data()
# presumably ~80% of the keys go to training, the rest to validation —
# confirm against split_data.
train_keys, validation_keys = split_data(ground_truth_data, training_ratio=.8)

# instantiating model
model = SSD300(image_shape, num_classes, weights_path, frozen_layers)
multibox_loss = MultiboxLoss(num_classes, neg_pos_ratio=2.0).compute_loss
model.compile(optimizer, loss=multibox_loss, metrics=['acc'])

# setting parameters for data augmentation generator
# Prior boxes are computed from the model instance, so the model must be
# built before this call.
prior_boxes = create_prior_boxes(model)
image_generator = ImageGenerator(ground_truth_data,
                                 prior_boxes,
                                 num_classes,
                                 box_scale_factors,
                                 batch_size,
                                 image_shape[0:2],  # (height, width) only
                                 train_keys,
                                 validation_keys,
                                 image_prefix,
                                 vertical_flip_probability=0.5,
                                 horizontal_flip_probability=0.5)

# instantiating callbacks
# Note: learning_rate_schedule is not passed to any fit call in this fragment.
learning_rate_schedule = LearningRateScheduler(scheduler)
# The parentheses are redundant: this is a plain name, not a tuple.
model_names = (trained_models_filename)
# save_best_only=False: a checkpoint is written every epoch, not only when
# val_loss improves.
model_checkpoint = ModelCheckpoint(model_names,
                                   monitor='val_loss',
                                   verbose=1,
                                   save_best_only=False,
                                   save_weights_only=False)
def main():
    """Train a mini-XCEPTION classifier for the task chosen on the command line.

    Hyper-parameters come from ``args()``. ``mode`` selects the number of
    output classes (gender: 2, age: 101, emotion: 7, anything else: 5), and
    ``graymode`` selects a 64x64 input with 1 (grayscale) or 3 channels. If
    ``load_model`` is set, training starts from that saved model instead of a
    fresh mini_XCEPTION. Progress is logged to a CSV file and the best
    checkpoints (by val_acc) are saved under
    ``../trained_models/<mode>_models/<dataset>_model/``.

    Returns:
        None.
    """
    # parameters
    param = args()
    batch_size = param.batch_size
    num_epochs = param.num_epochs
    validation_split = param.val_ratio
    patience = param.patience
    dataset_name = param.dataset_name
    grayscale = param.graymode
    mode = param.mode
    anno_file = param.anno_file

    # Output size per task; any unrecognised mode falls back to 5 classes.
    num_classes = {"gender": 2, "age": 101, "emotion": 7}.get(mode, 5)
    input_shape = (64, 64, 1 if grayscale else 3)

    images_path = param.img_dir
    log_file_path = '../trained_models/%s_models/%s_model/training.log' % (
        mode, dataset_name)
    trained_models_path = '../trained_models/%s_models/%s_model/%s_mini_XCEPTION' % (
        mode, dataset_name, mode)
    pretrained_model = param.load_model
    print("-------begin to load data------", input_shape)

    # data loader: separate generators for the training and validation keys
    data_loader = DataManager(dataset_name, anno_file)
    ground_truth_data = data_loader.get_data()
    train_keys, val_keys = split_imdb_data(ground_truth_data, validation_split)
    print('Number of training samples:', len(train_keys))
    print('Number of validation samples:', len(val_keys))
    train_image_generator = ImageGenerator(ground_truth_data,
                                           batch_size,
                                           input_shape[:2],
                                           train_keys,
                                           path_prefix=images_path,
                                           grayscale=grayscale)
    val_image_generator = ImageGenerator(ground_truth_data,
                                         batch_size,
                                         input_shape[:2],
                                         val_keys,
                                         path_prefix=images_path,
                                         grayscale=grayscale)

    # model parameters/compilation
    if pretrained_model is not None:
        # compile=False: the model is recompiled below with this script's
        # optimizer, loss and metrics.
        model = load_model(pretrained_model, compile=False)
        # NOTE(review): the loaded model's input shape is only printed, not
        # checked against input_shape used by the generators.
        print("pretrained model:", model.input_shape)
    else:
        model = mini_XCEPTION(input_shape, num_classes)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()

    # model callbacks
    early_stop = EarlyStopping('val_acc', patience=patience)
    # Use half the early-stopping patience so the learning rate is lowered
    # (and given epochs to take effect) before EarlyStopping ends training;
    # with equal patience on the same monitor both tend to fire together.
    reduce_lr = ReduceLROnPlateau('val_acc',
                                  factor=0.1,
                                  patience=max(1, int(patience) // 2),
                                  verbose=1,
                                  min_lr=1e-7)
    csv_logger = CSVLogger(log_file_path, append=False)
    # Checkpoint filenames embed the epoch number and validation accuracy.
    model_names = trained_models_path + '.{epoch:02d}-{val_acc:.2f}.hdf5'
    model_checkpoint = ModelCheckpoint(model_names,
                                       monitor='val_acc',
                                       verbose=1,
                                       save_best_only=True,
                                       save_weights_only=False)
    callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]

    # training model; ceil so the final partial batch is included each epoch
    print("-----begin to train model----")
    model.fit_generator(
        train_image_generator.flow(),
        steps_per_epoch=int(np.ceil(len(train_keys) / batch_size)),
        epochs=num_epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=val_image_generator.flow(),
        validation_steps=int(np.ceil(len(val_keys) / batch_size)))