# Number of regression outputs of the network head; passed to
# mobileNet_v3.mobilenet below as its second positional argument.
output_regression_count = 6

tf.reset_default_graph()

# Inference input pipeline: the graph is fed a path to a JPEG file, decodes it,
# rescales the pixel values and resizes the result to 224x224.
# For simplicity we just decode the JPEG inside TensorFlow, but any input
# tensor could be fed instead.
file_input = tf.placeholder(tf.string, ())
# channels=3 forces an RGB decode so grayscale/CMYK JPEGs still satisfy the
# 3-channel set_shape below instead of failing at run time.
image = tf.image.decode_jpeg(tf.read_file(file_input), channels=3)
images = tf.expand_dims(image, 0)  # add a leading batch dimension of size 1
# Map uint8 pixels [0, 255] to roughly [-1, 1) via x / 128 - 1.
images = tf.cast(images, tf.float32) / 128. - 1
images.set_shape((None, None, None, 3))
images = tf.image.resize_images(images, (224, 224))

# Note: arg_scope is optional for inference.
with tf.contrib.slim.arg_scope(mobileNet_v3.training_scope(is_training=False)):
    logits, endpoints = mobileNet_v3.mobilenet(images, output_regression_count)

# Restoring from the exponential moving average of the weights is said to give
# 1.5-2% higher accuracy; currently disabled.
# ema = tf.train.ExponentialMovingAverage(0.999)
# vars = ema.variables_to_restore()

# When the dataset was normalized, load the normalization parameters as a list
# of whitespace-separated string tokens. The file is closed even on error.
if normalized_dataset:
    with open(normalized_file_name, "r") as text_file:
        normalization_params = text_file.read().split()

saver = tf.train.Saver()
# Validation image file names, filtered by the project helper filterImages.
file_names = os.listdir(dataset_path + validation_path)
file_names = filterImages(file_names)
# Training input pipeline: (image, label) tensors drawn from `dataset`, which is
# defined elsewhere. NOTE(review): `iter` shadows the builtin; the name is kept
# because later code may reference it (e.g. to run its initializer).
iter = dataset.make_initializable_iterator()
input_images, labels = iter.get_next()

# When starting from a base MobileNet checkpoint, these logits-layer variables
# are passed to the checkpoint helper as names to ignore. Presumably they are
# not restored because their shape differs from the new output layer; confirm.
varsToIgnore = []
if loadFromBaseMobilenet:
    varsToIgnore = [
        "MobilenetV2/Logits/Conv2d_1c_1x1/biases",
        "MobilenetV2/Logits/Conv2d_1c_1x1/weights",
    ]

# Variable names from the checkpoint, as returned by the project helper
# print_tensors_in_checkpoint_file. TODO(review): confirm it drops varsToIgnore.
varList = print_tensors_in_checkpoint_file(checkpoint, varsToIgnore)

# NOTE(review): if an inference mobilenet() was already built in this same
# default graph, this second call may collide with its variables; confirm the
# cells run against separate graphs.
with tf.contrib.slim.arg_scope(mobileNet_v3.training_scope(is_training=True)):
    logits, endpoints = mobileNet_v3.mobilenet(input_images,
                                               outputRegressionCount)

# Gather the graph variables that match a checkpoint name. Only these are handed
# to the Saver, so only they are restored. get_collection's `scope` filters
# with re.match on the variable name. NOTE(review): `vars` shadows the builtin;
# the name is kept in case later code references it.
vars = []
for name in varList:
    vars.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name))

saver = tf.train.Saver(vars)

# Build the training op and loss tensors with the project helper, then create
# scalar summaries for both losses.
trainStep, mseLoss, maeLoss = constructTrainingParams(input_images, logits,
                                                      labels)
summ_loss = tf.summary.scalar("mse_loss", mseLoss)
summ_mae_loss = tf.summary.scalar("mae_loss", maeLoss)
# Number of regression outputs of the network head; passed to
# mobileNet_v3.mobilenet below as its second positional argument.
output_regression_count = 3

tf.reset_default_graph()

# Inference input pipeline: the graph is fed a path to a JPEG file, decodes it,
# rescales the pixel values and resizes the result to 224x224.
# For simplicity we just decode the JPEG inside TensorFlow, but any input
# tensor could be fed instead.
file_input = tf.placeholder(tf.string, ())
# channels=3 forces an RGB decode so grayscale/CMYK JPEGs still satisfy the
# 3-channel set_shape below instead of failing at run time.
image = tf.image.decode_jpeg(tf.read_file(file_input), channels=3)
images = tf.expand_dims(image, 0)  # add a leading batch dimension of size 1
# Map uint8 pixels [0, 255] to roughly [-1, 1) via x / 128 - 1.
images = tf.cast(images, tf.float32) / 128. - 1
images.set_shape((None, None, None, 3))
images = tf.image.resize_images(images, (224, 224))

# Note: arg_scope is optional for inference.
# Use the variable rather than a literal so the head size has one source of truth.
with tf.contrib.slim.arg_scope(mobileNet_v3.training_scope(is_training=False)):
    logits, endpoints = mobileNet_v3.mobilenet(images, output_regression_count)

# Restoring from the exponential moving average of the weights is said to give
# 1.5-2% higher accuracy; currently disabled.
# ema = tf.train.ExponentialMovingAverage(0.999)
# vars = ema.variables_to_restore()

# When the dataset was normalized, load the normalization parameters as a list
# of whitespace-separated string tokens. The file is closed even on error.
if normalized_dataset:
    with open(normalized_file_name, "r") as text_file:
        normalization_params = text_file.read().split()

saver = tf.train.Saver()
# Validation image file names, filtered by the project helper filterImages.
file_names = os.listdir(dataset_path + validation_path)
file_names = filterImages(file_names)