Esempio n. 1
0
def train():
    """Train the multi-image merging network (confidence CNN + "alex" classifier).

    Builds a TF1 graph that reads ``cnn_db_loader.NUMBER_IMAGES`` images per
    example and estimates a confidence map for each one with
    ``cnn_tf_graphs.confidence_cnn23``. The weights are shared across images
    via variable reuse. The images are merged with
    ``cnn_tf_graphs.merge_isomaps_softmax``, the first 3 channels of the merge
    are kept, and the result is classified with the "alex" network using a
    softmax loss. Training runs in a ``tf.train.MonitoredTrainingSession``
    until ``FLAGS.max_steps``.

    Relies on module-level globals that are not visible here: ``db_dir``,
    ``PaSC_still_BASE``, ``PaSC_video_BASE``, ``CASIA_BASE``, ``FLAGS``,
    ``INITIAL_LEARNING_RATE``, ``LEARNING_RATE_DECAY_FACTOR``, ``MOMENTUM``,
    ``RESTORE`` and ``train_dir``.

    NOTE(review): a second top-level ``def train()`` appears later in this
    file and rebinds the name, so this version is shadowed once the module
    has been fully executed.
    """
    with tf.Graph().as_default():
        # The binding is unused; the call matters because it creates the
        # global step that optimize_loss and the hooks below look up again.
        global_step = tf.contrib.framework.get_or_create_global_step()

        pasc_still = cnn_db_loader.PaSC_still_loader(outputfolder=db_dir,
                                                     db_base=PaSC_still_BASE)
        pasc_video = cnn_db_loader.PaSC_video_loader(outputfolder=db_dir,
                                                     db_base=PaSC_video_BASE)
        casia = cnn_db_loader.CASIA_webface_loader(outputfolder=db_dir,
                                                   db_base=CASIA_BASE)
        # PaSC still and CASIA are marked entirely as training data; PaSC
        # video is split with train_proportion=0.8.
        pasc_still.set_all_as_train()
        casia.set_all_as_train()
        pasc_video.split_train_eval(train_proportion=0.8)
        db_loader = cnn_db_loader.Aggregator(pasc_video, pasc_still, casia)
        #db_loader = cnn_db_loader.Aggregator(pasc_still)

        # True division, so this is a float under Python 3. It is used as the
        # epoch length (in steps) by the LR schedule and the checkpoint hook.
        num_batches_per_epoch = len(
            db_loader.examples_train) / FLAGS.batch_size

        images_list, labels_list = db_loader.get_training_multi_image_and_label_lists(
        )

        #images = [0]*cnn_db_loader.NUMBER_IMAGES
        output = tf_utils.inputs_multi(images_list,
                                       labels_list,
                                       FLAGS.batch_size,
                                       db_loader.get_mean_image_path(),
                                       png_with_alpha=True,
                                       image_size=256)
        #output = tf_utils.inputs_multi(images_list, labels_list, FLAGS.batch_size, db_loader.get_mean_image_path(), png_with_alpha=False, image_size=512)
        # The first NUMBER_IMAGES outputs are used as the image tensors and
        # the last output as the label tensor.
        images = output[:cnn_db_loader.NUMBER_IMAGES]
        labels = output[-1]
        #print (output)
        #for i in range(cnn_db_loader.NUMBER_IMAGES):
        #	images[i], labels = tf_utils.inputs([image[i] for image in images_list], labels_list, FLAGS.batch_size, db_loader.get_mean_image_path())

        # One confidence tensor per input image. After the first iteration
        # scope.reuse_variables() makes every further call share the same
        # confidence-network variables.
        confs = [0] * cnn_db_loader.NUMBER_IMAGES
        with tf.variable_scope("confidence_estimation") as scope:
            for i in range(cnn_db_loader.NUMBER_IMAGES):
                confs[i] = cnn_tf_graphs.confidence_cnn23(images[i],
                                                          input_size=256)
                #confs[i] = cnn_tf_graphs.confidence_cnn4(images[i], input_size=512)
                scope.reuse_variables()
                #tf.get_variable_scope().reuse_variables()

        merging_input_list = [[images[i], confs[i]]
                              for i in range(cnn_db_loader.NUMBER_IMAGES)]
        merged_image = cnn_tf_graphs.merge_isomaps_softmax(merging_input_list)

        # Keep only the first 3 channels of the last axis. This presumably
        # drops the alpha channel loaded via png_with_alpha=True (confirm).
        merged_image = tf.slice(merged_image, [0, 0, 0, 0], [-1, -1, -1, 3])

        # Build a Graph that computes the logits predictions from the inference model.
        logits, _ = cnn_tf_graphs.inference(network="alex",
                                            mode=learn.ModeKeys.TRAIN,
                                            batch_size=FLAGS.batch_size,
                                            num_classes=db_loader.number_ids,
                                            input_image_tensor=merged_image,
                                            image_size=256)
        #logits, _ = cnn_tf_graphs.inference(network="alex", mode=learn.ModeKeys.TRAIN, batch_size=FLAGS.batch_size, num_classes=db_loader.number_ids, input_image_tensor=merged_image, image_size=512)

        # Calculate loss.
        #loss = cnn_tf_graphs.l2_loss(logits, labels)
        loss = cnn_tf_graphs.softmax_loss(logits, labels, db_loader.number_ids)

        # Top-1 accuracy of the current batch, in percent.
        top_k_op = tf.nn.in_top_k(logits, labels, 1)
        sum_correct = tf.reduce_sum(tf.cast(top_k_op, tf.float32))
        accuracy = tf.divide(tf.multiply(sum_correct, tf.constant(100.0)),
                             tf.constant(float(FLAGS.batch_size)))
        #accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(tf.argmax(logits,1), tf.argmax(labels, 1))

        # lr is built as a constant, but _LearningRateSetterHook overrides its
        # value on every step through feed_dict.
        lr = tf.constant(INITIAL_LEARNING_RATE, tf.float32)
        tf.summary.scalar('learning_rate', lr)
        tf.summary.scalar('momentum', MOMENTUM)
        tf.summary.scalar('batch_size', FLAGS.batch_size)
        tf.summary.scalar('accuracy', accuracy)

        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,
                                               momentum=MOMENTUM)
        #optimizer=tf.train.AdadeltaOptimizer(learning_rate=lr)

        # variables=None: optimize every trainable variable (both the
        # confidence network and the classifier).
        train_op = tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=tf.contrib.framework.get_global_step(),
            learning_rate=lr,
            optimizer=optimizer,
            variables=None)
        #variables=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "confidence_estimation"))

        logging_hook = tf.train.LoggingTensorHook(tensors={
            'step':
            tf.contrib.framework.get_global_step(),
            'loss':
            loss,
            'lr':
            lr,
            'acc':
            accuracy
        },
                                                  every_n_iter=100)

        #saver = tf.train.Saver(var_list=None, keep_checkpoint_every_n_hours=1)
        # Checkpoint saver for all variables; max_to_keep=None keeps every
        # checkpoint.
        saver = tf.train.Saver(var_list=None, max_to_keep=None)
        if RESTORE:
            # NOTE(review): computed but never used (only referenced by the
            # commented-out restorer below).
            classification_network_variables = [
                var
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
                if var not in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                "confidence_estimation")
            ]
            all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            print('all vars:', len(all_variables))
            # Variables of the confidence network's "deconv3" sub-scope and
            # their optimizer copies under "OptimizeLoss/". These are excluded
            # from the restore, so they keep their fresh initial values.
            conf_conv3_variables = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES,
                "confidence_estimation/deconv3/")
            conf_conv3_optimize = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES,
                "OptimizeLoss/confidence_estimation/deconv3/")
            print('conv3 vars',
                  len(conf_conv3_variables + conf_conv3_optimize))
            good_ones = [
                var for var in all_variables
                if (var not in conf_conv3_variables) and (
                    var not in conf_conv3_optimize)
            ]
            #print (type(all_variables))
            #restorer = tf.train.Saver(var_list=classification_network_variables, max_to_keep=None)
            restorer = tf.train.Saver(var_list=good_ones, max_to_keep=None)

        class _LearningRateSetterHook(tf.train.SessionRunHook):
            """Sets learning_rate based on global step."""
            def begin(self):
                """Set the learning rate used for the very first step."""
                # NOTE(review): the first step runs at INITIAL * DECAY**6,
                # a smaller rate than any step after after_run() has run.
                # Confirm this is intentional.
                self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**6
                #print(self.num_batches_per_epoch)

            def before_run(self, run_context):
                """Fetch the global step and feed the current rate into lr."""
                return tf.train.SessionRunArgs(
                    tf.contrib.framework.get_global_step(
                    ),  # Asks for global step value.
                    feed_dict={lr: self._lrn_rate})  # Sets learning rate

            def after_run(self, run_context, run_values):
                """Pick the next step's rate: decay once every 1.5 epochs, up to 5 times."""
                train_step = run_values.results
                self._lrn_rate = INITIAL_LEARNING_RATE
                #training_epoch = int(train_step/num_batches_per_epoch)
                #self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**int(train_step/num_batches_per_epoch/2.7)
                if train_step < 1.5 * num_batches_per_epoch:
                    self._lrn_rate = INITIAL_LEARNING_RATE
                elif train_step < 3.0 * num_batches_per_epoch:
                    self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**1
                elif train_step < 4.5 * num_batches_per_epoch:
                    self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**2
                elif train_step < 6.0 * num_batches_per_epoch:
                    self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**3
                elif train_step < 7.5 * num_batches_per_epoch:
                    self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**4
                else:
                    self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**5

        config = tf.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=FLAGS.log_device_placement)
        config.gpu_options.allow_growth = True

        # NOTE(review): besides the explicit per-epoch CheckpointSaverHook,
        # save_checkpoint_secs=3600 likely makes the session add its own
        # time-based checkpoint hook for the same directory. Confirm the
        # duplication is intended.
        with tf.train.MonitoredTrainingSession(
                is_chief=True,
                checkpoint_dir=train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    tf.train.CheckpointSaverHook(
                        checkpoint_dir=train_dir,
                        save_steps=num_batches_per_epoch,
                        saver=saver), logging_hook,
                    _LearningRateSetterHook()
                ],
                config=config,
                save_checkpoint_secs=3600) as mon_sess:
            #saver.restore(mon_sess,'/user/HS204/m09113/my_project_folder/cnn_experiments/28/train_first_part/model.ckpt-21575')
            if RESTORE:
                # Runs after the session is created, so the values loaded from
                # the latest checkpoint in train_dir + '_restore' presumably
                # replace whatever the session initialised for good_ones.
                restorer.restore(
                    mon_sess,
                    tf.train.latest_checkpoint(train_dir + '_restore'))
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
def eval(saved_model_path, db_loader):
    """Evaluate a saved checkpoint and return its top-1 precision.

    Builds an evaluation graph whose input pipeline and network are selected
    by ``cnn_db_loader.NUMBER_ALPHAS`` / ``NUMBER_IMAGES`` / ``NUMBER_XYZ``,
    restores ``saved_model_path`` and runs
    ``ceil(len(db_loader.examples_eval) / FLAGS.batch_size)`` batches.

    Args:
        saved_model_path: Checkpoint path passed to ``tf.train.Saver.restore``.
        db_loader: Database loader providing the evaluation lists, the mean
            image/xyz paths, ``number_ids`` and ``examples_eval``.

    Returns:
        Fraction of evaluated samples whose label is the top-1 prediction.
        The denominator is the batches actually run times the batch size, so
        a clean early stop no longer leaves the result unbound. Returns 0.0
        if no batch completed.

    Raises:
        ValueError: If the NUMBER_* configuration matches none of the
            supported graph variants.
    """
    with tf.Graph().as_default():
        if cnn_db_loader.NUMBER_ALPHAS > 0:
            image_list, alphas_list, labels_list = db_loader.get_eval_image_alphas_and_label_lists()

            images, alphas, labels = tf_utils.inputs_with_alphas(
                image_list, alphas_list, labels_list, FLAGS.batch_size,
                db_loader.get_mean_image_path())

            # Build a Graph that computes the logits predictions from the inference model.
            logits, _ = cnn_tf_graphs.inference(
                network="alex_with_alpha", mode=learn.ModeKeys.EVAL,
                batch_size=FLAGS.batch_size, num_classes=db_loader.number_ids,
                input_image_tensor=images, input_alpha_tensor=alphas)

        elif (cnn_db_loader.NUMBER_IMAGES == 1 and cnn_db_loader.NUMBER_ALPHAS == 0
              and cnn_db_loader.NUMBER_XYZ == 0):
            image_list, labels_list = db_loader.get_eval_image_and_label_lists()

            images, labels = tf_utils.inputs(
                image_list, labels_list, FLAGS.batch_size,
                db_loader.get_mean_image_path(), image_size=256)

            logits, _ = cnn_tf_graphs.inference(
                network="alex", mode=learn.ModeKeys.EVAL,
                batch_size=FLAGS.batch_size, num_classes=db_loader.number_ids,
                input_image_tensor=images, image_size=256)

        elif (cnn_db_loader.NUMBER_IMAGES == 0 and cnn_db_loader.NUMBER_ALPHAS == 0
              and cnn_db_loader.NUMBER_XYZ == 1):
            image_list, labels_list = db_loader.get_eval_xyz_and_label_lists()

            images, labels = tf_utils.inputs(
                image_list, labels_list, FLAGS.batch_size,
                db_loader.get_mean_xyz_path())

            logits, _ = cnn_tf_graphs.inference(
                network="alex", mode=learn.ModeKeys.EVAL,
                batch_size=FLAGS.batch_size, num_classes=db_loader.number_ids,
                input_image_tensor=images)

        elif (cnn_db_loader.NUMBER_ALPHAS == 0 and cnn_db_loader.NUMBER_IMAGES == 1
              and cnn_db_loader.NUMBER_XYZ == 1):
            image_list, xyz_list, labels_list = db_loader.get_eval_image_xyz_and_label_lists()

            isomap_stacks, labels = tf_utils.inputs_stack_image_and_xyz(
                image_list, xyz_list, labels_list, FLAGS.batch_size,
                db_loader.get_mean_image_path(), db_loader.get_mean_xyz_path())

            logits, _ = cnn_tf_graphs.inference(
                network="dcnn", mode=learn.ModeKeys.EVAL,
                batch_size=FLAGS.batch_size, num_classes=db_loader.number_ids,
                input_image_tensor=isomap_stacks)

        elif (cnn_db_loader.NUMBER_ALPHAS == 0 and cnn_db_loader.NUMBER_IMAGES > 1
              and cnn_db_loader.NUMBER_XYZ == 0):
            images_list, labels_list = db_loader.get_eval_multi_image_and_label_lists()

            output = tf_utils.inputs_multi(
                images_list, labels_list, FLAGS.batch_size,
                db_loader.get_mean_image_path(), png_with_alpha=True,
                image_size=256)
            # The first NUMBER_IMAGES outputs are the images, the last is the labels.
            images = output[:cnn_db_loader.NUMBER_IMAGES]
            labels = output[-1]

            # One confidence map per image; variables are shared after the
            # first call (same structure as the multi-image training graph).
            confs = [0] * cnn_db_loader.NUMBER_IMAGES
            with tf.variable_scope("confidence_estimation") as scope:
                for i in range(cnn_db_loader.NUMBER_IMAGES):
                    confs[i] = cnn_tf_graphs.confidence_cnn23(images[i], input_size=256)
                    scope.reuse_variables()

            merging_input_list = [[images[i], confs[i]]
                                  for i in range(cnn_db_loader.NUMBER_IMAGES)]
            merged_image = cnn_tf_graphs.merge_isomaps_softmax(merging_input_list)

            # Keep only the first 3 channels of the merged tensor.
            merged_image = tf.slice(merged_image, [0, 0, 0, 0], [-1, -1, -1, 3])

            logits, _ = cnn_tf_graphs.inference(
                network="alex", mode=learn.ModeKeys.EVAL,
                batch_size=FLAGS.batch_size, num_classes=db_loader.number_ids,
                input_image_tensor=merged_image, image_size=256)

        else:
            # Previously fell through and failed later with an unbound `logits`.
            raise ValueError(
                'Unsupported input configuration: NUMBER_ALPHAS=%s, '
                'NUMBER_IMAGES=%s, NUMBER_XYZ=%s'
                % (cnn_db_loader.NUMBER_ALPHAS, cnn_db_loader.NUMBER_IMAGES,
                   cnn_db_loader.NUMBER_XYZ))

        # Boolean per sample: is the label the top-1 prediction?
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        saver = tf.train.Saver()

        config = tf.ConfigProto(allow_soft_placement=False,
                                log_device_placement=FLAGS.log_device_placement)
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            print('restore model')
            saver.restore(sess, saved_model_path)

            print('we have', len(db_loader.examples_eval), 'images to evaluate')

            # Start the queue runners. These are bound before the `try` so the
            # shutdown code and the final precision always have them.
            coord = tf.train.Coordinator()
            threads = []
            true_count = 0  # Counts the number of correct predictions.
            step = 0
            try:
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))

                num_iter = int(math.ceil(len(db_loader.examples_eval) / FLAGS.batch_size))
                print('so have to run this', num_iter, 'times. Let\'s start! iter: ', end=' ')
                sys.stdout.flush()
                while step < num_iter and not coord.should_stop():
                    predictions = sess.run([top_k_op])
                    true_count += np.sum(predictions)
                    step += 1
                    print(step, end=' ')
                    sys.stdout.flush()
            except Exception as e:  # pylint: disable=broad-except
                # Hand the error to the coordinator. join() re-raises it unless
                # the coordinator treats it as a clean stop (e.g. the input
                # queue reporting OutOfRangeError).
                coord.request_stop(e)

            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)

    # Precision @ 1 over the batches that actually ran. On a normal run
    # step == num_iter, which matches the original num_iter * batch_size.
    evaluated = step * FLAGS.batch_size
    return true_count / evaluated if evaluated else 0.0
def train():
    """Train an identity classifier on PaSC + CASIA isomaps and their random merges.

    Builds a TF1 training graph whose input pipeline and network are selected
    by ``cnn_db_loader.NUMBER_ALPHAS`` / ``NUMBER_IMAGES`` / ``NUMBER_XYZ``.
    The training data aggregates the PaSC still, PaSC video and CASIA-WebFace
    loaders with their pre-computed "random_merges" counterparts. Training
    runs in a ``tf.train.MonitoredTrainingSession`` with a piecewise-constant
    learning-rate schedule until ``FLAGS.max_steps`` is reached.

    Relies on module-level globals that are not visible here: ``db_dir``,
    ``experiment_dir``, ``train_dir``, ``PaSC_still_BASE``,
    ``PaSC_video_BASE``, ``CASIA_BASE``, ``FLAGS``,
    ``INITIAL_LEARNING_RATE``, ``LEARNING_RATE_DECAY_FACTOR`` and
    ``MOMENTUM``.

    Side effect: leaves ``cnn_db_loader.IMAGE_FILE_ENDING`` set to ``'/*'``.

    Raises:
        ValueError: If the NUMBER_* configuration matches none of the
            supported graph variants.
    """
    with tf.Graph().as_default():
        # Called for its side effect of creating the global step, which
        # optimize_loss and the hooks below look up again.
        tf.contrib.framework.get_or_create_global_step()

        # NOTE(review): the loaders appear to read IMAGE_FILE_ENDING when they
        # are constructed, so each assignment must precede its loaders. Confirm.
        cnn_db_loader.IMAGE_FILE_ENDING = '/*isomap.png'
        pasc_still = cnn_db_loader.PaSC_still_loader(outputfolder=db_dir,
                                                     db_base=PaSC_still_BASE)
        pasc_video = cnn_db_loader.PaSC_video_loader(outputfolder=db_dir,
                                                     db_base=PaSC_video_BASE)
        casia = cnn_db_loader.CASIA_webface_loader(outputfolder=db_dir,
                                                   db_base=CASIA_BASE)

        # The random-merge databases are read with a match-everything ending.
        cnn_db_loader.IMAGE_FILE_ENDING = '/*'
        merges_output_folder = experiment_dir + '/db_input_merges/'
        pasc_still_merges = cnn_db_loader.PaSC_still_loader(
            outputfolder=merges_output_folder,
            db_base=
            '/user/HS204/m09113/my_project_folder/PaSC/still/random_merges_256_conf13/'
        )
        pasc_video_merges = cnn_db_loader.PaSC_video_loader(
            outputfolder=merges_output_folder,
            db_base=
            '/user/HS204/m09113/my_project_folder/PaSC/video/random_merges_256_conf13/'
        )
        casia_merges = cnn_db_loader.CASIA_webface_loader(
            outputfolder=merges_output_folder,
            db_base=
            '/user/HS204/m09113/my_project_folder/CASIA_webface/random_merges_256_conf13/'
        )

        # Order kept as-is: split_train_eval may depend on random state.
        pasc_still_merges.set_all_as_train()
        pasc_video_merges.split_train_eval(train_proportion=0.8)
        casia_merges.set_all_as_train()

        pasc_still.set_all_as_train()
        casia.set_all_as_train()
        pasc_video.split_train_eval(train_proportion=0.8)
        db_loader = cnn_db_loader.Aggregator(pasc_still, pasc_still_merges,
                                             pasc_video, pasc_video_merges,
                                             casia, casia_merges)

        # True division (float under Python 3); the epoch length in steps.
        num_batches_per_epoch = len(
            db_loader.examples_train) / FLAGS.batch_size

        if cnn_db_loader.NUMBER_ALPHAS > 0 and cnn_db_loader.NUMBER_XYZ == 0:
            image_list, alphas_list, labels_list = db_loader.get_training_image_alphas_and_label_lists(
            )

            images, alphas, labels = tf_utils.inputs_with_alphas(
                image_list, alphas_list, labels_list, FLAGS.batch_size,
                db_loader.get_mean_image_path())

            # Build a Graph that computes the logits predictions from the inference model.
            logits, _ = cnn_tf_graphs.inference(
                network="alex_with_alpha",
                mode=learn.ModeKeys.TRAIN,
                batch_size=FLAGS.batch_size,
                num_classes=db_loader.number_ids,
                input_image_tensor=images,
                input_alpha_tensor=alphas)

        elif cnn_db_loader.NUMBER_ALPHAS == 0 and cnn_db_loader.NUMBER_IMAGES == 1 and cnn_db_loader.NUMBER_XYZ == 0:
            image_list, labels_list = db_loader.get_training_image_and_label_lists(
            )

            images, labels = tf_utils.inputs(image_list,
                                             labels_list,
                                             FLAGS.batch_size,
                                             db_loader.get_mean_image_path(),
                                             image_size=256)

            logits, _ = cnn_tf_graphs.inference(
                network="alex",
                mode=learn.ModeKeys.TRAIN,
                batch_size=FLAGS.batch_size,
                num_classes=db_loader.number_ids,
                input_image_tensor=images,
                image_size=256)

        elif cnn_db_loader.NUMBER_ALPHAS == 0 and cnn_db_loader.NUMBER_IMAGES == 0 and cnn_db_loader.NUMBER_XYZ == 1:
            image_list, labels_list = db_loader.get_training_xyz_and_label_lists(
            )

            images, labels = tf_utils.inputs(image_list, labels_list,
                                             FLAGS.batch_size,
                                             db_loader.get_mean_xyz_path())

            logits, _ = cnn_tf_graphs.inference(
                network="alex",
                mode=learn.ModeKeys.TRAIN,
                batch_size=FLAGS.batch_size,
                num_classes=db_loader.number_ids,
                input_image_tensor=images)

        elif cnn_db_loader.NUMBER_ALPHAS == 0 and cnn_db_loader.NUMBER_IMAGES == 1 and cnn_db_loader.NUMBER_XYZ == 1:
            image_list, xyz_list, labels_list = db_loader.get_training_image_xyz_and_label_lists(
            )

            isomap_stacks, labels = tf_utils.inputs_stack_image_and_xyz(
                image_list, xyz_list, labels_list, FLAGS.batch_size,
                db_loader.get_mean_image_path(), db_loader.get_mean_xyz_path())

            # NOTE(review): eval() and test() build network="dcnn" for this
            # image+xyz configuration, but training uses "alex". Confirm the
            # checkpoints are compatible.
            logits, _ = cnn_tf_graphs.inference(
                network="alex",
                mode=learn.ModeKeys.TRAIN,
                batch_size=FLAGS.batch_size,
                num_classes=db_loader.number_ids,
                input_image_tensor=isomap_stacks)

        else:
            # Previously fell through and failed later with an unbound `logits`.
            raise ValueError(
                'Unsupported input configuration: NUMBER_ALPHAS=%s, '
                'NUMBER_IMAGES=%s, NUMBER_XYZ=%s'
                % (cnn_db_loader.NUMBER_ALPHAS, cnn_db_loader.NUMBER_IMAGES,
                   cnn_db_loader.NUMBER_XYZ))

        # Calculate loss.
        loss = cnn_tf_graphs.softmax_loss(logits, labels, db_loader.number_ids)

        # Top-1 accuracy of the current batch, in percent.
        top_k_op = tf.nn.in_top_k(logits, labels, 1)
        sum_correct = tf.reduce_sum(tf.cast(top_k_op, tf.float32))
        accuracy = tf.divide(tf.multiply(sum_correct, tf.constant(100.0)),
                             tf.constant(float(FLAGS.batch_size)))

        # Built as a constant; _LearningRateSetterHook overrides its value on
        # every step through feed_dict.
        lr = tf.constant(INITIAL_LEARNING_RATE, tf.float32)
        tf.summary.scalar('learning_rate', lr)
        tf.summary.scalar('momentum', MOMENTUM)
        tf.summary.scalar('batch_size', FLAGS.batch_size)
        tf.summary.scalar('accuracy', accuracy)

        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,
                                               momentum=MOMENTUM)

        train_op = tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=tf.contrib.framework.get_global_step(),
            learning_rate=lr,
            optimizer=optimizer)

        logging_hook = tf.train.LoggingTensorHook(tensors={
            'step': tf.contrib.framework.get_global_step(),
            'loss': loss,
            'lr': lr,
            'acc': accuracy
        }, every_n_iter=100)

        # max_to_keep=None keeps every checkpoint.
        saver = tf.train.Saver(var_list=None, max_to_keep=None)

        # Epoch boundaries at which the learning rate is multiplied by
        # LEARNING_RATE_DECAY_FACTOR once more.
        decay_epochs = (2, 4, 6, 9, 12)

        class _LearningRateSetterHook(tf.train.SessionRunHook):
            """Sets learning_rate based on global step."""

            def begin(self):
                """Set the learning rate used for the very first step."""
                # NOTE(review): the first step runs at INITIAL * DECAY**6,
                # which is lower than any later step. Confirm this is intended.
                self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**6

            def before_run(self, run_context):
                """Fetch the global step and feed the current rate into lr."""
                return tf.train.SessionRunArgs(
                    tf.contrib.framework.get_global_step(),  # Asks for global step value.
                    feed_dict={lr: self._lrn_rate})  # Sets learning rate

            def after_run(self, run_context, run_values):
                """Pick the next step's rate from the piecewise-constant schedule."""
                train_step = run_values.results
                # Exponent = number of epoch boundaries already reached (0..5).
                exponent = sum(train_step >= epoch * num_batches_per_epoch
                               for epoch in decay_epochs)
                self._lrn_rate = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR**exponent

        config = tf.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=FLAGS.log_device_placement)
        config.gpu_options.allow_growth = True

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    tf.train.CheckpointSaverHook(
                        checkpoint_dir=train_dir,
                        save_steps=num_batches_per_epoch,
                        saver=saver), logging_hook,
                    _LearningRateSetterHook()
                ],
                config=config,
                save_checkpoint_secs=3600) as mon_sess:
            # Honour should_stop(). The previous `while True` kept calling
            # run() after StopAtStepHook requested a stop, and
            # MonitoredSession.run raises RuntimeError in that case.
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
def test(saved_model_path, images, alphas=None):
    """Compute one feature vector per input image with a restored model.

    Builds a batch-size-1 inference graph whose variant is selected by the
    module-level ``NUMBER_ALPHAS`` / ``NUMBER_IMAGES`` / ``NUMBER_XYZ``
    settings, restores ``saved_model_path`` and feeds the images one at a
    time.

    Relies on module-level globals that are not visible here: ``db_dir``,
    ``mean_name``, ``mean_name_xyz``, ``file_ending``, ``file_ending_xyz``
    and ``FLAGS``.

    Args:
        saved_model_path: Checkpoint path passed to ``tf.train.Saver.restore``.
        images: Sequence of image file paths.
        alphas: Per-image alpha vectors (63 values each), indexed in parallel
            with ``images``. Only read when ``NUMBER_ALPHAS == 1``. ``None``
            (the default) means no alphas.

    Returns:
        A numpy array with one row per image. Each row is the first (and
        only) batch row of the second output of ``cnn_tf_graphs.inference``.

    Raises:
        ValueError: If the module-level configuration matches none of the
            supported graph variants.
    """
    # None instead of a shared mutable default list.
    if alphas is None:
        alphas = []

    with tf.Graph().as_default():

        image_path_tensor = tf.placeholder(tf.string)
        image_tf = tf_utils.single_input_image(image_path_tensor,
                                               db_dir + mean_name,
                                               image_size=256)
        image_tf = tf.expand_dims(image_tf, 0)

        # Extra inputs used only by some graph variants; None when unused.
        alphas_placeholder = None
        xyz_path_tensor = None

        # Build a Graph that computes the feature vector from the inference model.
        # NOTE(review): num_classes=10868 must match the checkpoint's classifier.
        if NUMBER_ALPHAS == 0 and NUMBER_IMAGES == 1 and NUMBER_XYZ == 0:
            _, feature_vector_tensor = cnn_tf_graphs.inference(
                network="alex",
                mode=learn.ModeKeys.EVAL,
                batch_size=1,
                num_classes=10868,
                input_image_tensor=image_tf,
                image_size=256)

        elif NUMBER_ALPHAS == 1 and NUMBER_IMAGES == 1 and NUMBER_XYZ == 0:
            alphas_placeholder = tf.placeholder(tf.float32, shape=(63,))
            alphas_tf = tf.expand_dims(alphas_placeholder, 0)
            _, feature_vector_tensor = cnn_tf_graphs.inference(
                network="alex_with_alpha",
                mode=learn.ModeKeys.EVAL,
                batch_size=1,
                num_classes=10868,
                input_image_tensor=image_tf,
                input_alpha_tensor=alphas_tf)

        elif NUMBER_ALPHAS == 0 and NUMBER_IMAGES == 1 and NUMBER_XYZ == 1:
            xyz_path_tensor = tf.placeholder(tf.string)
            # Bug fix: read the xyz map from its own placeholder. Previously
            # the RGB image path was read a second time here and the fed xyz
            # path was ignored.
            xyz_tf = tf_utils.single_input_image(xyz_path_tensor,
                                                 db_dir + mean_name_xyz)
            xyz_tf = tf.expand_dims(xyz_tf, 0)
            # Stack image and xyz along the last axis.
            stack_tf = tf.concat([image_tf, xyz_tf], axis=3)
            _, feature_vector_tensor = cnn_tf_graphs.inference(
                network="dcnn",
                mode=learn.ModeKeys.EVAL,
                batch_size=1,
                num_classes=10868,
                input_image_tensor=stack_tf)

        else:
            # Previously fell through and failed later with an unbound name.
            raise ValueError(
                'Unsupported input configuration: NUMBER_ALPHAS=%s, '
                'NUMBER_IMAGES=%s, NUMBER_XYZ=%s'
                % (NUMBER_ALPHAS, NUMBER_IMAGES, NUMBER_XYZ))

        saver = tf.train.Saver()

        vectors = np.empty([len(images), feature_vector_tensor.shape[1]])

        config = tf.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=FLAGS.log_device_placement)
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            print('restore model')
            saver.restore(sess, saved_model_path)
            print('restoring done')

            for idx, image_path in enumerate(images):
                if idx % 1000 == 0:
                    print(idx, 'of', len(images))
                feed = {image_path_tensor: image_path}
                if alphas_placeholder is not None:
                    feed[alphas_placeholder] = np.array(alphas[idx])
                if xyz_path_tensor is not None:
                    # The xyz file sits next to the image with a different ending.
                    feed[xyz_path_tensor] = image_path.replace(file_ending, file_ending_xyz)
                vector = sess.run(feature_vector_tensor, feed_dict=feed)
                vectors[idx, :] = vector[0]
    return vectors