def run():
    """Build a MobileUNet-Skip segmentation graph, train it and export it.

    The graph (input/label placeholders, network, loss and Adam train op) is
    created first. Inside a session the variables are then initialised,
    ``train_nn`` is called, the trained session is exported with
    ``helper.save_model`` to a timestamped sub-directory of
    ``SAVE_MODEL_DIR``, and inference samples for ``TEST_DIR`` are written
    below ``./runs``.
    """
    # Hyper-parameters and paths.
    n_classes = 3
    img_shape = IMG_SIZE
    train_root = TRAINING_DIR
    samples_root = './runs'
    n_epochs = 100
    per_batch = 1
    # NOTE(review): this value is only forwarded to train_nn; the Adam
    # optimizer below is hard-coded to 1e-4. Confirm which rate is intended.
    lr = 1e-5

    height, width = img_shape[0], img_shape[1]

    images_ph = tf.placeholder(tf.float32,
                               shape=[None, height, width, 3],
                               name="net_input")
    labels_ph = tf.placeholder(tf.float32,
                               shape=[None, height, width, n_classes],
                               name="net_output")

    raw_output = build_mobile_unet(images_ph,
                                   preset_model='MobileUNet-Skip',
                                   num_classes=n_classes)
    # Explicitly named so the output can be looked up by name later.
    logits = tf.identity(raw_output, name='logits')
    loss_op = custom_loss(logits, labels_ph)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(
        loss_op,
        var_list=list(tf.trainable_variables()),
        name='optimizer')

    with tf.Session() as sess:
        batches = helper.gen_batch_function(os.path.join(train_root),
                                            RGB_DIR, SEG_DIR, img_shape)
        initializer = tf.global_variables_initializer()
        # Constructed but never used in this function.
        _saver = tf.train.Saver()

        sess.run(initializer)
        train_nn(sess, n_epochs, per_batch, batches, train_op, loss_op,
                 images_ph, labels_ph, lr)

        # Export under a per-run timestamp (YYYY-MM-DD-HHMM).
        stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H%M")
        export_dir = os.path.join(SAVE_MODEL_DIR, stamp)
        helper.save_model(sess, images_ph, logits, export_dir)
        print("SavedModel saved at {}".format(export_dir))

        # NOTE(review): every other call goes through `helper`; confirm that
        # a separate `helper3` module is really imported here.
        helper3.save_inference_samples(samples_root, TEST_DIR, sess,
                                       img_shape, logits, images_ph)
# Example #2
# 0
    def inference(self, net_input, num_classes, is_training):
        """Build the segmentation network selected by ``FLAGS.net``.

        When ``FLAGS.patch_slim`` is set, ``fuck_slim.patch(is_training)`` is
        called first. The chosen builder then runs inside
        ``slim.arg_scope(aardvark.default_argscope(is_training))``.

        Args:
            net_input: Input tensor handed to the selected ``build_*`` function.
            num_classes: Number of output classes.
            is_training: Training-mode flag; forwarded to the slim patch, the
                default arg scope and the builders that accept it.

        Returns:
            The network produced by the builder. The builder's init function
            (``None`` for builders that do not return one) is stored in
            ``self.init_fn``.

        Raises:
            ValueError: If ``FLAGS.net`` is not a supported model name.
        """
        if FLAGS.patch_slim:
            fuck_slim.patch(is_training)
        model = FLAGS.net

        init_fn = None
        # Every builder uses the same default arg scope, so enter it once.
        with slim.arg_scope(aardvark.default_argscope(is_training)):
            if model in ("FC-DenseNet56", "FC-DenseNet67", "FC-DenseNet103"):
                network = build_fc_densenet(net_input, preset_model=model,
                                            num_classes=num_classes)
            elif model in ("RefineNet-Res50", "RefineNet-Res101",
                           "RefineNet-Res152"):
                # RefineNet requires pre-trained ResNet weights.
                network, init_fn = build_refinenet(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model in ("FRRN-A", "FRRN-B"):
                network = build_frrn(net_input, preset_model=model,
                                     num_classes=num_classes)
            elif model in ("Encoder-Decoder", "Encoder-Decoder-Skip"):
                network = build_encoder_decoder(net_input, preset_model=model,
                                                num_classes=num_classes)
            elif model in ("MobileUNet", "MobileUNet-Skip"):
                network = build_mobile_unet(net_input, preset_model=model,
                                            num_classes=num_classes)
            elif model in ("PSPNet-Res50", "PSPNet-Res101", "PSPNet-Res152"):
                # PSPNet needs the label (crop) size and pre-trained ResNet
                # weights.
                # NOTE(review): `args` is not defined in this method and every
                # other setting comes from FLAGS; confirm `args` exists at
                # module level.
                network, init_fn = build_pspnet(
                    net_input,
                    label_size=[args.crop_height, args.crop_width],
                    preset_model=model,
                    num_classes=num_classes,
                    is_training=is_training)
            elif model in ("GCN-Res50", "GCN-Res101", "GCN-Res152"):
                # GCN requires pre-trained ResNet weights.
                network, init_fn = build_gcn(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model in ("DeepLabV3-Res50", "DeepLabV3-Res101",
                           "DeepLabV3-Res152"):
                # DeepLabV3 requires pre-trained ResNet weights.
                network, init_fn = build_deeplabv3(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model in ("DeepLabV3_plus-Res50", "DeepLabV3_plus-Res101",
                           "DeepLabV3_plus-Res152"):
                # DeepLabV3+ requires pre-trained ResNet weights.
                network, init_fn = build_deeplabv3_plus(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model == "AdapNet":
                network = build_adaptnet(net_input, num_classes=num_classes)
            else:
                # Include the offending name (the old message had an
                # unformatted "%d" placeholder instead).
                raise ValueError(
                    "Error: the model {} is not available. Try checking which "
                    "models are available using the command python main.py "
                    "--help".format(model))

        self.init_fn = init_fn
        return network
# Example #3
# 0
# NOTE(review): fragment of a larger if/elif chain that picks a network from
# args.model. The opening `if` branch and the body of the final DeepLabV3
# branch are not part of this snippet.
elif args.model == "RefineNet-Res50" or args.model == "RefineNet-Res101" or args.model == "RefineNet-Res152":
    # RefineNet requires pre-trained ResNet weights; the builder also
    # returns an init_fn.
    network, init_fn = build_refinenet(net_input,
                                       preset_model=args.model,
                                       num_classes=num_classes)
elif args.model == "FRRN-A" or args.model == "FRRN-B":
    # The FRRN builder returns only the network (no init_fn).
    network = build_frrn(net_input,
                         preset_model=args.model,
                         num_classes=num_classes)
elif args.model == "Encoder-Decoder" or args.model == "Encoder-Decoder-Skip":
    network = build_encoder_decoder(net_input,
                                    preset_model=args.model,
                                    num_classes=num_classes)
elif args.model == "MobileUNet" or args.model == "MobileUNet-Skip":
    network = build_mobile_unet(net_input,
                                preset_model=args.model,
                                num_classes=num_classes)
elif args.model == "PSPNet-Res50" or args.model == "PSPNet-Res101" or args.model == "PSPNet-Res152":
    # Image size is required for PSPNet (passed as label_size from the
    # crop height/width arguments).
    # PSPNet requires pre-trained ResNet weights
    network, init_fn = build_pspnet(
        net_input,
        label_size=[args.crop_height, args.crop_width],
        preset_model=args.model,
        num_classes=num_classes)
elif args.model == "GCN-Res50" or args.model == "GCN-Res101" or args.model == "GCN-Res152":
    # GCN requires pre-trained ResNet weights
    network, init_fn = build_gcn(net_input,
                                 preset_model=args.model,
                                 num_classes=num_classes)
elif args.model == "DeepLabV3-Res50" or args.model == "DeepLabV3-Res101" or args.model == "DeepLabV3-Res152":
# Example #4
# 0
def run():
    """Train a segmentation network, optionally resuming from a saved model.

    If ``args.load_model`` is set, the graph is restored from that SavedModel
    directory, and the input, label, logits, loss and optimizer handles are
    fetched by the names given in ``args``. Otherwise a fresh MobileUNet-Skip
    graph is built. All global variables are then initialised. If the
    directory ``<args.load_model>-ckpt`` exists, ``model.ckpt`` inside it is
    restored. Finally ``train_nn`` is called with a saver that keeps one
    checkpoint, a new checkpoint path and a timestamped export directory.
    """
    n_classes = 3
    img_shape = IMG_SIZE
    n_epochs = args.num_epochs
    per_batch = args.batch_size
    # NOTE(review): forwarded to train_nn only; the optimizer built below is
    # hard-coded to 1e-4.
    lr = 1e-5

    session_config = tf.ConfigProto()
    # Enable global JIT compilation at level ON_1.
    session_config.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_1)

    with tf.Session(config=session_config) as sess:
        if args.load_model is None:
            height, width = img_shape[0], img_shape[1]
            images_ph = tf.placeholder(tf.float32,
                                       shape=[None, height, width, 3],
                                       name="net_input")
            raw_output = build_mobile_unet(images_ph,
                                           preset_model='MobileUNet-Skip',
                                           num_classes=n_classes)
            logits = tf.identity(raw_output, name='logits')
            labels_ph = tf.placeholder(tf.float32,
                                       shape=[None, height, width, n_classes],
                                       name="net_output")
            loss_op = custom_loss(logits, labels_ph)
            train_op = tf.train.AdamOptimizer(1e-4).minimize(
                loss_op,
                var_list=list(tf.trainable_variables()),
                name='optimizer')
        else:
            tf.saved_model.loader.load(sess,
                                       [tf.saved_model.tag_constants.SERVING],
                                       args.load_model)
            graph = tf.get_default_graph()
            images_ph = graph.get_tensor_by_name(args.load_net_input_name)
            labels_ph = graph.get_tensor_by_name(args.load_net_output_name)
            logits = graph.get_tensor_by_name(args.load_logits_name)
            loss_op = graph.get_tensor_by_name(args.load_loss_name)
            train_op = graph.get_operation_by_name(args.load_optimizer_name)

        batches = helper.gen_batch_function(TRAINING_DIRS, RGB_DIR, SEG_DIR,
                                            args)
        initializer = tf.global_variables_initializer()

        # Saver keeps only the most recent checkpoint.
        saver = tf.train.Saver(max_to_keep=1)
        stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H%M")
        save_dir = os.path.join(SAVE_MODEL_DIR, stamp)
        ckpt_path = os.path.join(SAVE_MODEL_DIR, '{}-ckpt'.format(stamp),
                                 'model.ckpt')

        # NOTE(review): this runs after the SavedModel load above, so loaded
        # variable values are re-initialised. Only the optional checkpoint
        # restore below brings trained values back.
        sess.run(initializer)

        if args.load_model is not None:
            resume_dir = '{}-ckpt'.format(args.load_model)
            resume_path = os.path.join(resume_dir, 'model.ckpt')
            if os.path.exists(resume_dir):
                print("Loads checkpoint", resume_path)
                saver.restore(sess, resume_path)
            else:
                print("Checkpoint", resume_path,
                      "not found. Restart training instead.")

        train_nn(sess, n_epochs, per_batch, batches, train_op, loss_op,
                 images_ph, labels_ph, lr, saver, ckpt_path, logits, save_dir)
# Example #5
# 0
def buildNetwork(model, net_input, num_class, label_size=None):
    """Build the segmentation network named ``model``.

    If ``model`` contains "Res50", "Res101" or "Res152" and the matching
    ``models/resnet_v2_*.ckpt`` file does not exist yet, the pre-trained
    ResNet-v2 checkpoint is downloaded first.

    Args:
        model: Model name, e.g. "FC-DenseNet103", "PSPNet-Res101" or "custom".
        net_input: Input tensor passed to the builder.
        num_class: Number of output classes.
        label_size: ``[height, width]`` used only by PSPNet. When omitted,
            ``[args.crop_height, args.crop_width]`` is read from the global
            ``args`` (the previous, and still default, behaviour).

    Returns:
        ``(network, init_fn)``. ``init_fn`` is ``None`` for builders that do
        not return one.

    Raises:
        ValueError: If ``model`` is not a supported model name.
    """
    # Some models require pre-trained ResNet weights.
    # NOTE(review): Res50 calls a bare ``download_checkpoints`` while
    # Res101/Res152 call ``utils.download_checkpoints``; confirm both names
    # are in scope.
    if "Res50" in model and not os.path.isfile("models/resnet_v2_50.ckpt"):
        download_checkpoints("ResnetV2", "50")
    if "Res101" in model and not os.path.isfile("models/resnet_v2_101.ckpt"):
        utils.download_checkpoints("ResnetV2", "101")
    if "Res152" in model and not os.path.isfile("models/resnet_v2_152.ckpt"):
        utils.download_checkpoints("ResnetV2", "152")

    init_fn = None
    if model in ("FC-DenseNet56", "FC-DenseNet67", "FC-DenseNet103"):
        network = build_fc_densenet(net_input,
                                    preset_model=model,
                                    num_classes=num_class)
    elif model in ("RefineNet-Res50", "RefineNet-Res101", "RefineNet-Res152"):
        # RefineNet requires pre-trained ResNet weights.
        network, init_fn = build_refinenet(net_input,
                                           preset_model=model,
                                           num_classes=num_class)
    elif model in ("FRRN-A", "FRRN-B"):
        network = build_frrn(net_input,
                             preset_model=model,
                             num_classes=num_class)
    elif model in ("Encoder-Decoder", "Encoder-Decoder-Skip"):
        network = build_encoder_decoder(net_input,
                                        preset_model=model,
                                        num_classes=num_class)
    elif model in ("MobileUNet", "MobileUNet-Skip"):
        network = build_mobile_unet(net_input,
                                    preset_model=model,
                                    num_classes=num_class)
    elif model in ("PSPNet-Res50", "PSPNet-Res101", "PSPNet-Res152"):
        # PSPNet needs the label (crop) size and pre-trained ResNet weights.
        if label_size is None:
            label_size = [args.crop_height, args.crop_width]
        network, init_fn = build_pspnet(net_input,
                                        label_size=label_size,
                                        preset_model=model,
                                        num_classes=num_class)
    elif model in ("GCN-Res50", "GCN-Res101", "GCN-Res152"):
        # GCN requires pre-trained ResNet weights.
        network, init_fn = build_gcn(net_input,
                                     preset_model=model,
                                     num_classes=num_class)
    elif model in ("DeepLabV3-Res50", "DeepLabV3-Res101", "DeepLabV3-Res152"):
        # DeepLabV3 requires pre-trained ResNet weights.
        network, init_fn = build_deeplabv3(net_input,
                                           preset_model=model,
                                           num_classes=num_class)
    elif model in ("DeepLabV3_plus-Res50", "DeepLabV3_plus-Res101",
                   "DeepLabV3_plus-Res152"):
        # DeepLabV3+ requires pre-trained ResNet weights.
        network, init_fn = build_deeplabv3_plus(net_input,
                                                preset_model=model,
                                                num_classes=num_class)
    elif model == "AdapNet":
        network = build_adaptnet(net_input, num_classes=num_class)
    elif model == "custom":
        network = build_custom(net_input, num_class)
    else:
        # Include the offending name (the old message had an unformatted
        # "%d" placeholder instead).
        raise ValueError(
            "Error: the model {} is not available. Try checking which models "
            "are available using the command python main.py --help".format(
                model))
    return network, init_fn
# Example #6
# 0
def run():
    """Train a segmentation network, export it and write inference samples.

    If ``args.load_model`` is set, the graph and its saved variable values
    are restored from that SavedModel directory, and the tensors/ops are
    looked up by the names given in ``args``. Otherwise a fresh
    MobileUNet-Skip graph is built and all variables are initialised.

    After ``train_nn`` the session is exported with ``helper.save_model`` to
    a timestamped directory below ``SAVE_MODEL_DIR``. Inference samples for
    ``TEST_DIR`` are then written below ``./runs``.
    """
    num_classes = 3
    image_shape = IMG_SIZE
    data_dir = TRAINING_DIR
    runs_dir = './runs'
    epochs = args.num_epochs
    batch_size = args.batch_size
    # NOTE(review): passed to train_nn only; the Adam optimizer below is
    # hard-coded to 1e-4. Confirm which learning rate is intended.
    learning_rate = 1e-5

    config = tf.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    loading = args.load_model is not None

    with tf.Session(config=config) as sess:
        if loading:
            # Restores the graph and the saved variable values into `sess`.
            tf.saved_model.loader.load(
                sess, [tf.saved_model.tag_constants.SERVING], args.load_model)
            graph = tf.get_default_graph()
            net_input = graph.get_tensor_by_name(args.load_net_input_name)
            net_output = graph.get_tensor_by_name(args.load_net_output_name)
            network = graph.get_tensor_by_name(args.load_logits_name)
            loss = graph.get_tensor_by_name(args.load_loss_name)
            opt = graph.get_operation_by_name(args.load_optimizer_name)
        else:
            net_input = tf.placeholder(
                tf.float32,
                shape=[None, image_shape[0], image_shape[1], 3],
                name="net_input")

            network = build_mobile_unet(net_input,
                                        preset_model='MobileUNet-Skip',
                                        num_classes=num_classes)
            # Explicitly named so the output can be looked up by name later.
            network = tf.identity(network, name='logits')

            net_output = tf.placeholder(
                tf.float32,
                shape=[None, image_shape[0], image_shape[1], num_classes],
                name="net_output")

            loss = custom_loss(network, net_output)
            opt = tf.train.AdamOptimizer(1e-4).minimize(
                loss,
                var_list=list(tf.trainable_variables()),
                name='optimizer')

        # Create function to get batches
        get_batches_fn = helper.gen_batch_function(data_dir, RGB_DIR, SEG_DIR,
                                                   args)

        if not loading:
            sess.run(tf.global_variables_initializer())
        else:
            # A blanket global_variables_initializer() here would overwrite
            # the values the loader just restored, silently discarding the
            # loaded weights. Initialise only variables that are still
            # uninitialised (report_uninitialized_variables yields op names
            # as bytes).
            uninitialized = set(sess.run(tf.report_uninitialized_variables()))
            sess.run(tf.variables_initializer(
                [var for var in tf.global_variables()
                 if var.op.name.encode("utf-8") in uninitialized]))

        # Runs training
        train_nn(sess, epochs, batch_size, get_batches_fn, opt, loss,
                 net_input, net_output, learning_rate)

        # Save the trained model
        today = datetime.datetime.now().strftime("%Y-%m-%d-%H%M")
        save_dir = os.path.join(SAVE_MODEL_DIR, today)
        helper.save_model(sess, net_input, network, save_dir)

        print("SavedModel saved at {}".format(save_dir))

        helper.save_inference_samples(runs_dir, TEST_DIR, sess, image_shape,
                                      network, net_input)
# Example #7
# 0
def build_model(model_name,
                net_input,
                num_classes,
                frontend="ResNet101",
                is_training=True,
                label_size=None):
    """Build the segmentation model ``model_name`` on top of ``frontend``.

    Both names are validated against ``SUPPORTED_MODELS`` and
    ``SUPPORTED_FRONTENDS``. The frontend's pre-trained checkpoint is
    downloaded into ``models/`` if it is not present yet.

    Args:
        model_name: Name of the segmentation architecture.
        net_input: Input tensor passed to the builder.
        num_classes: Number of output classes.
        frontend: Feature-extractor name, forwarded to the builders that
            accept one.
        is_training: Forwarded to the builders that accept it.
        label_size: ``[height, width]`` used only by PSPNet. When omitted,
            ``[args.crop_height, args.crop_width]`` is read from the global
            ``args`` (the previous, and still default, behaviour).

    Returns:
        ``(network, init_fn)``. ``init_fn`` is ``None`` for builders that do
        not return one.

    Raises:
        ValueError: If the model or the frontend is not supported.
    """
    print("Preparing the model ...")

    if model_name not in SUPPORTED_MODELS:
        raise ValueError(
            "The model you selected is not supported. The following models are currently supported: {0}"
            .format(SUPPORTED_MODELS))

    if frontend not in SUPPORTED_FRONTENDS:
        raise ValueError(
            "The frontend you selected is not supported. The following frontends are currently supported: {0}"
            .format(SUPPORTED_FRONTENDS))

    # File whose presence marks each frontend's checkpoint as downloaded
    # (MobileNetV2 is checked through its data shard).
    frontend_checkpoints = {
        "ResNet50": "models/resnet_v2_50.ckpt",
        "ResNet101": "models/resnet_v2_101.ckpt",
        "ResNet152": "models/resnet_v2_152.ckpt",
        "MobileNetV2": "models/mobilenet_v2_1.4_224.ckpt.data-00000-of-00001",
        "InceptionV4": "models/inception_v4.ckpt",
    }
    checkpoint_file = frontend_checkpoints.get(frontend)
    if checkpoint_file is not None and not os.path.isfile(checkpoint_file):
        download_checkpoints(frontend)

    init_fn = None
    if model_name in ("FC-DenseNet56", "FC-DenseNet67", "FC-DenseNet103"):
        network = build_fc_densenet(net_input,
                                    preset_model=model_name,
                                    num_classes=num_classes)
    elif model_name == "RefineNet":
        # RefineNet requires pre-trained ResNet weights.
        network, init_fn = build_refinenet(net_input,
                                           preset_model=model_name,
                                           frontend=frontend,
                                           num_classes=num_classes,
                                           is_training=is_training)
    elif model_name in ("FRRN-A", "FRRN-B"):
        network = build_frrn(net_input,
                             preset_model=model_name,
                             num_classes=num_classes)
    elif model_name in ("Encoder-Decoder", "Encoder-Decoder-Skip"):
        network = build_encoder_decoder(net_input,
                                        preset_model=model_name,
                                        num_classes=num_classes)
    elif model_name in ("MobileUNet", "MobileUNet-Skip"):
        network = build_mobile_unet(net_input,
                                    preset_model=model_name,
                                    num_classes=num_classes)
    elif model_name == "PSPNet":
        # PSPNet needs the label (crop) size and pre-trained weights.
        if label_size is None:
            label_size = [args.crop_height, args.crop_width]
        network, init_fn = build_pspnet(net_input,
                                        label_size=label_size,
                                        preset_model=model_name,
                                        frontend=frontend,
                                        num_classes=num_classes,
                                        is_training=is_training)
    elif model_name == "GCN":
        # GCN requires pre-trained ResNet weights.
        network, init_fn = build_gcn(net_input,
                                     preset_model=model_name,
                                     frontend=frontend,
                                     num_classes=num_classes,
                                     is_training=is_training)
    elif model_name == "DeepLabV3":
        # DeepLabV3 requires pre-trained ResNet weights.
        network, init_fn = build_deeplabv3(net_input,
                                           preset_model=model_name,
                                           frontend=frontend,
                                           num_classes=num_classes,
                                           is_training=is_training)
    elif model_name == "DeepLabV3_plus":
        # DeepLabV3+ requires pre-trained ResNet weights.
        network, init_fn = build_deeplabv3_plus(net_input,
                                                preset_model=model_name,
                                                frontend=frontend,
                                                num_classes=num_classes,
                                                is_training=is_training)
    elif model_name == "AdapNet":
        network = build_adaptnet(net_input, num_classes=num_classes)
    elif model_name == "custom":
        network = build_custom(net_input, num_classes)
    else:
        # Reached when SUPPORTED_MODELS lists a name with no builder branch.
        raise ValueError(
            "Error: the model {} is not available. Try checking which models "
            "are available using the command python main.py --help".format(
                model_name))

    return network, init_fn