# Imports assumed by these graph-building/export helpers (TensorFlow 1.x API).
# `train` is the project-local module providing ImageClassificationTrainStrategy;
# its exact import path is an assumption here.
import json
import os

import tensorflow as tf
from tensorflow.python.framework import ops

import train


def export_train_graph(modelClass, optimizerClass, height, width, channels, classes):
    """Build a training graph for the given model/optimizer and export it as a .meta file."""
    graph = tf.Graph()
    with graph.as_default():
        # 1. instantiate the model
        model = modelClass(width, height, channels, classes)

        # 4. export train graph
        filename = "%s_%dx%dx%d_%d.meta" % (model.name.lower(), height, width, channels, classes)
        if os.path.exists(filename):
            print("file %s exists. skipping" % filename)
            return

        # 2. instantiate the optimizer
        optimizer = optimizerClass()

        # 3. instantiate the train wrapper
        trainStrategy = train.ImageClassificationTrainStrategy(
            graph, model, optimizer)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        tf.add_to_collection("init", init.name)
        tf.add_to_collection("train", trainStrategy.optimize)
        tf.add_to_collection("logits", trainStrategy.logits)
        tf.add_to_collection("summaries", trainStrategy.summary_op)
        tf.add_to_collection("predictions", model.predictions)

        meta = json.dumps({
            "inputs": {
                "batch_image_input": trainStrategy.inputs.name,
                "categorical_labels": trainStrategy.labels.name
            },
            "outputs": {
                "categorical_logits": model.logits.name
            },
            "metrics": {
                "accuracy": trainStrategy.accuracy.name,
                "total_loss": trainStrategy.loss.name
            },
            "parameters": {
                "global_step": trainStrategy.global_step.name,
                "learning_rate": trainStrategy._optimizer.learning_rate.name,
                "momentum": trainStrategy._optimizer.momentum.name
            },
        })
        tf.add_to_collection("meta", meta)

        tf.train.export_meta_graph(filename=filename, saver_def=saver.as_saver_def())
        print("model exported to ", filename)


def generate_train_graph(model_class, optimizer_class, width, height, channels, classes, add_summaries=False):
    graph = tf.Graph()
    with graph.as_default():
        # train/inference toggle; not bound to a variable here, looked up by name later
        tf.placeholder(tf.bool, name="global_is_training")

        # 1. instantiate the model
        model = model_class(width, height, channels, classes)

        batch_size = tf.placeholder(tf.float32, [], name="batch_size")

        # 2. instantiate the optimizer
        optimizer = optimizer_class()

        # 3. instantiate the train wrapper
        train_strategy = train.ImageClassificationTrainStrategy(
            graph, model, optimizer, batch_size, add_summaries=add_summaries)

    return train_strategy
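

# Usage sketch for generate_train_graph. The class names below are placeholders
# for illustration only; the real model/optimizer classes live elsewhere in this
# project, and ImageClassificationTrainStrategy's attributes are assumed:
#
#   strategy = generate_train_graph(
#       LeNet,              # hypothetical model class taking (width, height, channels, classes)
#       MomentumOptimizer,  # hypothetical optimizer factory taking no arguments
#       width=28, height=28, channels=1, classes=10,
#       add_summaries=True)
#   # the returned train strategy wraps the graph, the optimize op, and the
#   # input/label placeholders needed to drive a training session.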


def generate_train_graph(modelClass, optimizerClass, width, height, channels, classes):
    graph = tf.Graph()
    with graph.as_default():
        # 1. instantiate the model
        model = modelClass(width, height, channels, classes)

        # 2. instantiate the optimizer
        optimizer = optimizerClass()

        # 3. instantiate the train wrapper
        trainStrategy = train.ImageClassificationTrainStrategy(
            graph, model, optimizer)

        init = tf.global_variables_initializer()
        tf.add_to_collection("init", init.name)

    return trainStrategy


def export_train_graph(model_class, optimizer_class, height, width, channels, classes, output=None):
    """Build a training graph for the given model/optimizer and export it, with JSON metadata, as a .meta file."""
    graph = tf.Graph()
    with graph.as_default():
        global_is_training = tf.placeholder(tf.bool, name="global_is_training")
        batch_size = tf.placeholder(tf.float32, [], name="batch_size")

        # 1. instantiate the model
        model = model_class(width, height, channels, classes)

        # 2. export train graph
        filename = "%s_%dx%dx%d_%d.meta" % (model.name.lower(), width, height, channels, classes)
        if output is not None:
            if not output.endswith('/'):
                output = output + '/'
            filename = output + filename
        if os.path.exists(filename):
            print("file %s exists. skipping" % filename)
            return

        # 3. instantiate the optimizer
        optimizer = optimizer_class()

        # 4. instantiate the train wrapper
        train_strategy = train.ImageClassificationTrainStrategy(
            graph, model, optimizer, batch_size, add_summaries=True,
        )

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        tf.add_to_collection(ops.GraphKeys.INIT_OP, init.name)
        tf.add_to_collection(ops.GraphKeys.TRAIN_OP, train_strategy.optimize)
        tf.add_to_collection("logits", train_strategy.logits)
        tf.add_to_collection(ops.GraphKeys.SUMMARY_OP, train_strategy.summary_op)
        tf.add_to_collection("predictions", model.predictions)
        tf.add_to_collection(ops.GraphKeys.LOSSES, train_strategy.loss)

        basic_params = {
            "inputs": {
                "batch_image_input": train_strategy.inputs.name,
                "categorical_labels": train_strategy.labels.name
            },
            "outputs": {
                "categorical_logits": model.logits.name,
                # This can be removed when TF Java API implements get_operations
                "layers": ','.join([m.name for m in graph.get_operations()])
            },
            "metrics": {
                "accuracy": train_strategy.accuracy.name,
                "total_loss": train_strategy.loss.name
            },
            "parameters": {
                "global_step": train_strategy.global_step.name,
                "learning_rate": train_strategy._optimizer.learning_rate.name,
                "momentum": train_strategy._optimizer.momentum.name,
                "batch_size": train_strategy._batch_size.name,
                "global_is_training": global_is_training.name
            },
        }

        # optional per-model parameters, exported only when the model defines them
        if hasattr(model, 'hidden_dropout'):
            basic_params['parameters']['hidden_dropout'] = model._hidden_dropout.name
        if hasattr(model, 'input_dropout'):
            basic_params['parameters']['input_dropout'] = model._input_dropout.name
        if hasattr(model, 'activations'):
            basic_params['parameters']['activations'] = model._activations.name

        meta = json.dumps(basic_params)
        tf.add_to_collection("meta", meta)

        tf.train.export_meta_graph(filename=filename, saver_def=saver.as_saver_def())
        print("model exported to ", filename)
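

# Usage sketch for export_train_graph. The class names are placeholders for
# illustration; any model class exposing .name/.logits/.predictions and any
# optimizer factory compatible with ImageClassificationTrainStrategy is assumed
# to work:
#
#   export_train_graph(LeNet, MomentumOptimizer,
#                      height=28, width=28, channels=1, classes=10,
#                      output="exported_graphs/")
#
# This writes a file such as lenet_28x28x1_10.meta (model name, input shape,
# class count) containing the serialized MetaGraphDef plus a JSON "meta"
# collection that maps logical names (inputs, outputs, metrics, parameters) to
# tensor names, so a non-Python client such as the TF Java API can look them up
# without inspecting the graph.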