def train():
    """Train an FCN-8s segmentation model end to end on PASCAL VOC 2012.

    All paths and hyperparameters are hard-coded below. The function:

    1. opens a TF1-style interactive session (GPU memory grown on demand);
    2. loads the train and validation splits via ``VOCDataset.load_dataset``;
    3. builds the FCN from the pretrained VGG16 weight file;
    4. compiles it with an exponentially decaying learning-rate schedule;
    5. runs ``fit_model``, passing ``CHECKPOINT_DIR`` (presumably where
       checkpoints are written — confirm against ``fit_model``).

    The session is closed when training returns or raises, so its GPU
    memory is released instead of being held until process exit.
    """
    ROOT = './'
    VGG16_WEIGHT_PATH = './vgg/vgg16_weights.npz'
    DATASET_PATH = os.path.join(ROOT, 'VOC2012/')
    CHECKPOINT_DIR = os.path.join(DATASET_PATH, 'saved_model')

    IMAGE_SHAPE = (512, 512)
    N_CLASSES = 21
    N_EPOCHS = 100
    BATCH_SIZE = 1

    LEARNING_RATE = 1e-5
    DECAY_RATE = 0.95
    DECAY_EPOCH = 10
    DROPOUT_RATE = 0.5

    print('Starting end-to-end training FCN-8s')
    # allow_growth: allocate GPU memory on demand rather than reserving the
    # whole device up front. Use tf.compat.v1 throughout, matching the
    # InteractiveSession below, so the same code runs on TF 1.x and 2.x.
    session_config = tf.compat.v1.ConfigProto(
        gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
    # An InteractiveSession installs itself as the default session when it is
    # constructed, so no `session.as_default()` call is needed (calling it
    # outside a `with` statement would not do anything anyway).
    session = tf.compat.v1.InteractiveSession(config=session_config)
    try:
        # ------------- Load VOC from TFRecord ---------------
        dataset = VOCDataset()
        dataset_train = dataset.load_dataset(DATASET_PATH,
                                             BATCH_SIZE,
                                             is_training=True)
        dataset_val = dataset.load_dataset(DATASET_PATH,
                                           BATCH_SIZE,
                                           is_training=False)

        # ------------- Build fcn model ------------
        fcn = FCN(IMAGE_SHAPE, N_CLASSES, VGG16_WEIGHT_PATH)
        fcn.build_from_vgg()

        # The decay schedule needs the training-set size to convert
        # DECAY_EPOCH into a step count.
        learning_rate_fn = learning_rate_with_exp_decay(
            BATCH_SIZE,
            dataset.n_images['train'],
            DECAY_EPOCH, DECAY_RATE,
            LEARNING_RATE)
        compile_model(fcn, learning_rate_fn)
        fit_model(fcn, N_EPOCHS, BATCH_SIZE, dataset_train, dataset_val,
                  CHECKPOINT_DIR, DROPOUT_RATE)
    finally:
        # Release the session's resources (and uninstall it as the default
        # session) even if training fails part-way.
        session.close()
# Example #2
# 0
# Report which TensorFlow build is in use; helpful when diagnosing API drift.
print('Tensorflow version:', tf.__version__)

from dataset import VOCDataset
from fcn import FCN

# Killing optional CPU driver warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)

if __name__ == '__main__':
    image_shape = (512, 512)
    n_classes = 21
    vgg16_weights_path = './vgg/vgg16_weights.npz'
    model = FCN(image_shape, n_classes, vgg16_weights_path)
    model.build_from_vgg()

    root_path = './'
    dataset_path = os.path.join(root_path, 'VOC2012/')
    dataset = VOCDataset(augmentation_params=None)

    dataset_val = dataset.load_dataset(dataset_path,
                                       batch_size=8,
                                       is_training=False)
    iterator = tf.data.Iterator.from_structure(dataset_val.output_types,
                                               dataset_val.output_shapes)

    next_batch = iterator.get_next()
    val_init_op = iterator.make_initializer(dataset_val)

    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(