#import system things from tensorflow.examples.tutorials.mnist import input_data # for data import tensorflow as tf import numpy as np import os #import helpers import inference import visualize # prepare data and tf.session mnist = input_data.read_data_sets('MNIST_data', one_hot=False) sess = tf.InteractiveSession() # setup siamese network siamese = inference.siamese(); train_step = tf.train.GradientDescentOptimizer(0.01).minimize(siamese.loss) saver = tf.train.Saver() tf.initialize_all_variables().run() # if you just want to load a previously trainmodel? load = False model_ckpt = './model.meta' if os.path.isfile(model_ckpt): input_var = None while input_var not in ['yes', 'no']: input_var = input("We found model files. Do you want to load it and continue training [yes/no]?") if input_var == 'yes': load = True # start training
#net_type = 'metricAuto' ## metricOnly: only use metric learning, metricAuto: using autoencoder with metric learning net_type = 'holistic' # net_type = 'metricOnly' batch_train_size = 30 batch_test_size = 300 test_interval = 500 display_interval = 5000 learning_rate = 0.1 momentum = 0.1 # #resulTxt = net_type + '_noTanh.txt' # resulTxt = net_type + 'Appeend.txt' resulTxt = net_type + '.txt' logid = open(resulTxt, 'w') runOption = 1 ##Construct Network ############### siamese = inference.siamese(net_type=net_type) #model_path = './models/' + net_type + '_noTanh_bal0001'# The reconstruction loss without tanh for output, the weight for reconstruction loss is 0.001 model_path = os.path.join('./models', net_type, 'weight') weights_file = './models/metricOnly-195000' #train_step = tf.train.GradientDescentOptimizer(0.01).minimize(siamese.loss) if net_type == 'metricOnly': train_step = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum).minimize( siamese.loss) train_step_2 = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum).minimize( siamese.loss_2) train_step_1 = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum).minimize( siamese.loss_1) elif net_type == 'holistic':
#mnist = tf.keras.datasets.mnist #(train_images, train_labels),(test_images, test_labels) = mnist.load_data() mnist = input_data.read_data_sets('MNIST_data', one_hot=False) #create session config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.InteractiveSession() # setup siamese network ################################################################################## lossType = 'inverseWassersteinCL' margin = 5.0 siamese = inference.siamese(sess=sess, margin=margin, batch_size=batch_size, lamb=0.01, lossType=lossType) #normLabel = 1 learning_rate = tf.placeholder(tf.float32, [1]) train_step = tf.train.MomentumOptimizer(learning_rate=learning_rate[0], momentum=momentum).minimize( siamese.loss, var_list=siamese.var_list) # ############################################################################################################## saver = tf.train.Saver() # save results ? init = tf.global_variables_initializer() # initialization sess.run(init) #writer = tf.summary.FileWriter('./sinkhorn_board', sess.graph)
# --- Logging ---------------------------------------------------------------
# `logid` and `args` (argparse namespace) are defined earlier in the file,
# outside this view. Write a header, then reopen the log in append mode.
logid.write("Training results for {:10}".format(args.net_type))
logid.close()
logid = open(args.logfile, 'a')

# --- Checkpoint directory --------------------------------------------------
# One checkpoint directory per network variant, created on demand.
checkpoint_dir = os.path.join('../models', args.net_type)
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)
modelName = 'model'
# `get_checkpoint` is a helper defined elsewhere in this file — presumably it
# probes `checkpoint_dir` for an existing checkpoint to resume from; confirm.
ckpt_status = get_checkpoint(checkpoint_dir)

##################################################################################
# Build the siamese graph; margins/alpha are CLI-configured hyperparameters.
siamese = inference.siamese(
    net_type=args.net_type,
    margin_out=args.margin_out,
    margin_hid=args.margin_hid,
    alpha=args.alpha)  # net_type: choosing metric only or with holistic loss
# siamese = inference_old.siamese()
# Single Momentum optimizer on the combined loss.
train_step = tf.train.MomentumOptimizer(learning_rate=args.learning_rate, momentum=args.momentum).minimize(
    siamese.loss)
# train_step_2 = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum).minimize(siamese.loss_2)
# train_step_1 = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum).minimize(siamese.loss_1)
##############################################################################################################
# Keep (practically) every checkpoint rather than the default last 5.
saver = tf.train.Saver(max_to_keep=10000)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    print('Initializing all variables')
    # NOTE(review): the rest of the `with` body (variable init, training
    # loop) is outside this view — the chunk is truncated here.