예제 #1
0
def export_train_graph(modelClass, optimizerClass, height, width, channels,
                       classes, output=None):
    """Build a training graph in a fresh ``tf.Graph`` and export it as a meta graph.

    The file is named ``<model name lower>_<height>x<width>x<channels>_<classes>.meta``.
    It is written to the current working directory, or into ``output`` when one is
    given. If the file already exists, a message is printed and the export is skipped.
    The model is still instantiated first, because its name is part of the filename.

    The exported graph carries these collections:

    * "init": the *name* of the global-variables initializer op.
    * "train", "logits", "summaries", "predictions": the corresponding train-strategy
      or model objects.
    * "meta": a JSON string mapping logical roles (inputs, outputs, metrics,
      parameters) to tensor names.

    Args:
        modelClass: model constructor, called as
            ``modelClass(width, height, channels, classes)``.
        optimizerClass: zero-argument optimizer constructor.
        height: input image height.
        width: input image width.
        channels: number of input image channels.
        classes: number of output classes.
        output: optional directory for the exported file. ``None`` (the default)
            or ``""`` means the current working directory.

    Returns:
        None, on both the skip path and the export path.
    """
    graph = tf.Graph()
    with graph.as_default():
        # 1. instantiate the model (needed first: its name is part of the filename)
        model = modelClass(width, height, channels, classes)

        # 2. resolve the export filename; skip if it was already exported
        filename = "%s_%dx%dx%d_%d.meta" % (model.name.lower(), height, width,
                                            channels, classes)
        if output:
            filename = os.path.join(output, filename)
        if os.path.exists(filename):
            print("file %s exists. skipping" % filename)
            return

        # 3. instantiate the optimizer
        optimizer = optimizerClass()

        # 4. instantiate the train wrapper
        trainStrategy = train.ImageClassificationTrainStrategy(
            graph, model, optimizer)

        # 5. register the ops a consumer of the exported graph looks up
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        # The initializer is stored by name (a string), not as the op object.
        tf.add_to_collection("init", init.name)
        tf.add_to_collection("train", trainStrategy.optimize)
        tf.add_to_collection("logits", trainStrategy.logits)
        tf.add_to_collection("summaries", trainStrategy.summary_op)
        tf.add_to_collection("predictions", model.predictions)

        # JSON index of tensor names by logical role, stored in the graph itself.
        meta = json.dumps({
            "inputs": {
                "batch_image_input": trainStrategy.inputs.name,
                "categorical_labels": trainStrategy.labels.name
            },
            "outputs": {
                "categorical_logits": model.logits.name
            },
            "metrics": {
                "accuracy": trainStrategy.accuracy.name,
                "total_loss": trainStrategy.loss.name
            },
            "parameters": {
                "global_step": trainStrategy.global_step.name,
                "learning_rate": trainStrategy._optimizer.learning_rate.name,
                "momentum": trainStrategy._optimizer.momentum.name
            },
        })

        tf.add_to_collection("meta", meta)

        # 6. export the train graph
        tf.train.export_meta_graph(filename=filename,
                                   saver_def=saver.as_saver_def())
        print("model exported to %s" % filename)
예제 #2
0
def generate_train_graph(model_class, optimizer_class,
                         width, height, channels, classes, add_summaries=False):
    """Build a training graph in a new ``tf.Graph`` and return its train strategy.

    Two placeholders are created in the graph: a boolean "global_is_training" and
    a scalar float32 "batch_size". The latter is handed to the train strategy.

    Args:
        model_class: model constructor, called as
            ``model_class(width, height, channels, classes)``.
        optimizer_class: zero-argument optimizer constructor.
        width: input image width.
        height: input image height.
        channels: number of input image channels.
        classes: number of output classes.
        add_summaries: forwarded to ``ImageClassificationTrainStrategy``.

    Returns:
        The ``train.ImageClassificationTrainStrategy`` built on the new graph.
    """
    new_graph = tf.Graph()
    with new_graph.as_default():
        # The Python handle is discarded. Presumably the op is looked up by name
        # elsewhere (e.g. by the model); that lookup is not visible here.
        tf.placeholder(tf.bool, name="global_is_training")

        net = model_class(width, height, channels, classes)
        batch_size_ph = tf.placeholder(tf.float32, shape=[], name="batch_size")
        opt = optimizer_class()

        strategy = train.ImageClassificationTrainStrategy(
            new_graph,
            net,
            opt,
            batch_size_ph,
            add_summaries=add_summaries,
        )

    return strategy
예제 #3
0
def generate_train_graph(modelClass, optimizerClass, width, height, channels,
                         classes):
    """Build a training graph in a new ``tf.Graph`` and return its train strategy.

    After the strategy is built, the name of a global-variables initializer op is
    stored in the graph's "init" collection.

    Args:
        modelClass: model constructor, called as
            ``modelClass(width, height, channels, classes)``.
        optimizerClass: zero-argument optimizer constructor.
        width: input image width.
        height: input image height.
        channels: number of input image channels.
        classes: number of output classes.

    Returns:
        The ``train.ImageClassificationTrainStrategy`` built on the new graph.
    """
    g = tf.Graph()
    with g.as_default():
        # Arguments evaluate left to right: the model is built before the optimizer.
        strategy = train.ImageClassificationTrainStrategy(
            g,
            modelClass(width, height, channels, classes),
            optimizerClass(),
        )
        tf.add_to_collection("init", tf.global_variables_initializer().name)

    return strategy
예제 #4
0
def export_train_graph(model_class, optimizer_class,
                       height, width, channels, classes, output=None):
    """Build a training graph in a fresh ``tf.Graph`` and export it as a meta graph.

    The file is named ``<model name lower>_<width>x<height>x<channels>_<classes>.meta``.
    It is written to the current working directory, or into ``output`` when one is
    given. If the file already exists, a message is printed and the export is skipped.
    The model is still instantiated first, because its name is part of the filename.

    Besides the standard collections (init op name, train op, summary op, losses),
    the graph carries a "meta" collection entry. This is a JSON string mapping
    logical roles to tensor names. Under ``outputs.layers`` it also holds a
    comma-separated list of every op name in the graph.

    Args:
        model_class: model constructor, called as
            ``model_class(width, height, channels, classes)``.
        optimizer_class: zero-argument optimizer constructor.
        height: input image height.
        width: input image width.
        channels: number of input image channels.
        classes: number of output classes.
        output: optional directory for the exported file. ``None`` (the default)
            or ``""`` means the current working directory.

    Returns:
        None, on both the skip path and the export path.
    """
    graph = tf.Graph()
    with graph.as_default():
        global_is_training = tf.placeholder(tf.bool, name="global_is_training")

        batch_size = tf.placeholder(tf.float32, [], name="batch_size")

        # 1. instantiate the model
        model = model_class(width, height, channels, classes)

        # 2. resolve the export filename; skip if it was already exported
        filename = "%s_%dx%dx%d_%d.meta" % (model.name.lower(),
                                            width,
                                            height,
                                            channels,
                                            classes)

        if output:
            # os.path.join (rather than output + '/') keeps output="" relative
            # instead of turning it into the filesystem root.
            filename = os.path.join(output, filename)

        if os.path.exists(filename):
            print("file %s exists. skipping" % filename)
            return

        # 3. instantiate the optimizer
        optimizer = optimizer_class()

        # 4. instantiate the train wrapper
        train_strategy = train.ImageClassificationTrainStrategy(
            graph,
            model,
            optimizer,
            batch_size,
            add_summaries=True,
        )

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        # The initializer is stored by name (a string), not as the op object.
        tf.add_to_collection(ops.GraphKeys.INIT_OP, init.name)
        tf.add_to_collection(ops.GraphKeys.TRAIN_OP, train_strategy.optimize)
        tf.add_to_collection("logits", train_strategy.logits)
        tf.add_to_collection(ops.GraphKeys.SUMMARY_OP, train_strategy.summary_op)
        tf.add_to_collection("predictions", model.predictions)

        # The LOSSES collection previously received model.predictions (a
        # copy-paste slip). Register the actual loss instead, unless the train
        # strategy already put it there; an identity check avoids counting it twice.
        loss = train_strategy.loss
        if not any(t is loss for t in tf.get_collection(ops.GraphKeys.LOSSES)):
            tf.add_to_collection(ops.GraphKeys.LOSSES, loss)

        basic_params = {
            "inputs": {"batch_image_input": train_strategy.inputs.name,
                       "categorical_labels": train_strategy.labels.name},
            "outputs": {"categorical_logits": model.logits.name,
                        # This can be removed when TF Java API implements get_operations
                        "layers": ','.join(op.name for op in graph.get_operations())},
            "metrics": {"accuracy": train_strategy.accuracy.name,
                        "total_loss": train_strategy.loss.name},
            "parameters": {
                "global_step": train_strategy.global_step.name,
                "learning_rate": train_strategy._optimizer.learning_rate.name,
                "momentum": train_strategy._optimizer.momentum.name,
                "batch_size": train_strategy._batch_size.name,
                "global_is_training": global_is_training.name},
        }

        # Optional model tensors, added in this fixed order.
        # NOTE(review): presence is tested on the public attribute, but the tensor
        # is read from the "_"-prefixed private one. A model exposing only one of
        # the two is skipped or raises AttributeError. Confirm this is intended.
        parameters = basic_params['parameters']
        for attr in ('hidden_dropout', 'input_dropout', 'activations'):
            if hasattr(model, attr):
                parameters[attr] = getattr(model, '_' + attr).name

        meta = json.dumps(basic_params)

        tf.add_to_collection("meta", meta)

        tf.train.export_meta_graph(filename=filename,
                                   saver_def=saver.as_saver_def())
        print("model exported to %s" % filename)