Beispiel #1
0
def main():
    """Train a U-Net on the dataset configured in ``config.txt``.

    Configuration read:
      * ``[train]``: ``name`` (experiment name, used as the log
        sub-directory), ``epochs_num``, ``batch_size`` and ``datasets``
        (the name of another section that describes the dataset).
      * ``[<datasets>]``: ``h5py_save_path``, ``height``, ``width``,
        ``pad_height``, ``pad_width``.

    Side effects: creates ``./logs/<name>/`` if needed and writes into it
    preview images of the training images/labels/masks (also shown via
    ``.show()``), the model architecture as ``architecture.json`` and
    ``model.png``, and, during training, the weights with the lowest
    validation loss as ``weights.h5``.
    """
    config = configparser.RawConfigParser()
    config.read('config.txt')

    experiment_name = config.get('train', 'name')
    # Plain concatenation keeps everything under ./logs even if `name`
    # begins with a path separator (os.path.join would discard the prefix).
    log_dir = './logs/' + experiment_name
    # os.makedirs instead of os.system('mkdir ...'): no shell is spawned
    # (the name comes from a config file and may contain spaces or shell
    # metacharacters), it is portable, and nested names also work.
    os.makedirs(log_dir, exist_ok=True)
    epochs_num = config.getint('train', 'epochs_num')
    batch_size = config.getint('train', 'batch_size')

    # Load datasets.
    datasets = config.get('train', 'datasets')
    datasets_path = config.get(datasets, 'h5py_save_path')
    height = config.getint(datasets, 'height')
    width = config.getint(datasets, 'width')
    pad_height = config.getint(datasets, 'pad_height')
    pad_width = config.getint(datasets, 'pad_width')

    x_train, y_train, masks = Generator(datasets_path, 'train', height, width,
                                        pad_height, pad_width)()

    # Preview the raw training data (labels are still un-encoded here; they
    # are one-hot encoded right after).
    for images, filename in ((x_train, 'train_images.png'),
                             (y_train, 'train_labels.png'),
                             (masks, 'train_masks.png')):
        visualize(group_images(images, 4),
                  os.path.join(log_dir, filename)).show()
    y_train = to_categorical(y_train)

    # Build model and save its architecture.
    # NOTE(review): the second Unet argument (5) is presumably the number of
    # output classes — confirm against the Unet definition.
    unet = Unet((pad_height, pad_width, 1), 5)
    unet.summary()
    unet_json = unet.to_json()
    # The context manager guarantees the file is flushed and closed.
    with open(os.path.join(log_dir, 'architecture.json'), 'w') as json_file:
        json_file.write(unet_json)
    plot_model(unet, to_file=os.path.join(log_dir, 'model.png'))

    # Training: only the weights with the best (lowest) val_loss are kept.
    checkpointer = ModelCheckpoint(filepath=os.path.join(log_dir, 'weights.h5'),
                                   verbose=1,
                                   monitor='val_loss',
                                   mode='auto',
                                   save_best_only=True)

    unet.fit(
        x_train,
        y_train,
        epochs=epochs_num,
        batch_size=batch_size,
        verbose=1,
        shuffle=True,
        # Keras holds out the last 10% of the samples (before shuffling).
        validation_split=0.1,
        # class_weight is disabled; if re-enabled, Keras expects a dict
        # {class_index: weight}, not a tuple such as (0.5, 1.3).
        callbacks=[checkpointer])
Beispiel #2
0
        def get_lr(epoch):
            """Step-decay learning-rate schedule.

            Divides ``init_lr`` by ``lr_epoch_decay`` once per completed block
            of 10 epochs, and never returns a value below 1e-10.
            """
            decay_steps = epoch // 10
            return max(init_lr / lr_epoch_decay ** decay_steps, 1e-10)

        callback = LearningRateScheduler(get_lr)
        callbacks.append(callback)

    return callbacks


# Build the U-Net. `input_shape` and `num` are defined elsewhere in the file
# (`num` is presumably the number of output classes — TODO confirm).
model = Unet(input_shape, num)
model.summary()
if os.path.exists(mode1_path):  # resume training: load previously saved weights if present
    model.load_weights(mode1_path)
callbacks = make_callbacks()
# NOTE(review): hard-coded, machine-specific Windows data directory.
path = r'F:\cmm\yumi\pic'
# Training / validation items obtained from `path`; the 'tif' argument
# presumably selects the image file extension — confirm in get_train_val.
train_set, val_set = get_train_val(path, 'tif')
# Batch sources for training and validation (`train` is passed as the
# `generator` argument of fit_generator below). `bitchs` is defined elsewhere
# and is presumably the batch size — TODO confirm.
train = train_data(train_set, bitchs)
val = val_data(val_set, bitchs)
# Sample counts; presumably used to derive steps per epoch — verify.
alltrain = len(train_set)
allval = len(val_set)
# mean_iou from the local `loss` module is only tracked as a metric; the
# optimized loss is categorical cross-entropy.
loss_f = loss.mean_iou
metrics = [loss_f]
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=metrics)
A = model.fit_generator(generator=train,
Beispiel #3
0
from unet import Unet

# Instantiate the U-Net with its default constructor arguments and print
# its model summary.
unet = Unet()
unet.summary()