Example No. 1
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API

# ifd_train (hyper-parameter constants), inf (the network definition) and
# utils (patch-extraction helpers) are project-local modules; these import
# names are assumptions inferred from usage. update_actualFlag,
# update_predictFlag and write2file are assumed to be defined elsewhere
# in the same module.
import ifd_train
import inference as inf  # assumed module name behind the `inf` alias
import utils


def eval(INPUT_DATA):
	with tf.name_scope('data'):
		print("###############get data#################")
		processed_data = np.load(INPUT_DATA, allow_pickle=True)  # the file stores object arrays
		testing_imgs = processed_data[4]
		testing_labels = processed_data[5]
		print("###############finish get data!#################")


	with tf.Graph().as_default() as g:
		# define the input/output format; the batch dimension is left as
		# None because all patches of one image are fed in a single run
		patch = tf.placeholder(tf.float32, [None, ifd_train.PATCH_SIZE, ifd_train.PATCH_SIZE, ifd_train.INCHANNEL], name='test_patch')
		labels = tf.placeholder(tf.int64, [None], name='test_label')

		# call the network
		logits = inf.inference(patch, n_classes=ifd_train.N_CLASSES, TRAIN_FLAG=False)
		
		with tf.name_scope('accuracy'):
			label_one_hot = tf.one_hot(labels, depth=ifd_train.N_CLASSES, on_value=1)
			preds = tf.nn.softmax(logits)
			correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(label_one_hot, 1))
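			# reduce_sum yields the *count* of correct predictions; it is
			# divided by the number of patches after each sess.run below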
			accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))

		saver = tf.train.Saver()

		while True:
			with tf.Session() as sess:
				print("############  load the model  ################")
				ckpt = tf.train.get_checkpoint_state(ifd_train.MODEL_SAVE_PATH)
				if ckpt and ckpt.model_checkpoint_path:
					saver.restore(sess, ckpt.model_checkpoint_path)

					global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
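					# the training step is the numeric suffix of the checkpoint
					# filename, e.g. 'model-1000' -> '1000'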
					
					test_EPOCH = len(testing_imgs)
					acc_total_test = 0
					m_idx = 0
					
					TP = 0
					FP = 0
					TN = 0
					FN = 0
					
					# one iteration per test image: each image's patches form one batch
					for epoch in range(test_EPOCH):
						# load data
						patch_data, label_data, m_idx = utils.get_patch(testing_imgs[epoch], testing_labels[epoch],
						                                                ifd_train.PATCH_SIZE, ifd_train.STRIDE,
						                                                ifd_train.IMG_SIZE,
						                                                ifd_train.INCHANNEL, ifd_train.THRESHOLD, m_idx)
						N_PATCH = patch_data.shape[0]
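						# all patches of this image are scored as one batch, which
						# is why the placeholder batch dimension is None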
						
						acc_epoch, pred_labels = sess.run([accuracy, preds], feed_dict={patch: patch_data,
						                                                                labels: label_data})
						acc_epoch /= N_PATCH
						print('Average accuracy for image {0}: {1}'.format(epoch, acc_epoch))
						acc_total_test += acc_epoch
						
						# update the actual and predicted tamper flags for this image
						actual_Flag = update_actualFlag(testing_labels[epoch])
						pred_l = np.asarray(pred_labels)
						pred_l = np.reshape(pred_l, [-1, 2])  # 2 == ifd_train.N_CLASSES
						predict_Flag = update_predictFlag(pred_l)
						print('actual_Flag:{0}: predict_Flag:{1}'.format(actual_Flag, predict_Flag))

						# update TP, FP, TN, FN
						if predict_Flag and actual_Flag:
							TP += 1
						elif predict_Flag and (not actual_Flag):
							FP += 1
						elif (not predict_Flag) and (not actual_Flag):
							TN += 1
						else:
							FN += 1
					
					# final accuracy averaged over all test images
					acc = acc_total_test / test_EPOCH
					print('Average Accuracy at step {0}: {1}'.format(global_step, acc))
				
					write2file(ifd_train.OUTFILE, ifd_train.STRIDE, ifd_train.THRESHOLD, ifd_train.PATCH_SIZE,
					           TP, FP, TN, FN, acc)
				else:
					print('No checkpoint file found')
			return  # evaluate the latest checkpoint once, then exit
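
A minimal driver for this example might look as follows; the .npy path is hypothetical, and the file is expected to hold the six arrays indexed above:

if __name__ == '__main__':
	eval('processed_data.npy')  # hypothetical path to the preprocessed dataset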
Example No. 2
import os
import time
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API

# inf (the network definition), utils (patch helpers) and ifd_train
# (constants for the eval() below) are project-local modules; these import
# names are assumptions inferred from usage. The bare constants used in
# train() (INPUT_DATA, BATCH_SIZE, PATCH_SIZE, ...) are expected to be
# defined at module level.
import ifd_train
import inference as inf  # assumed module name behind the `inf` alias
import utils


def train():
    with tf.name_scope('data'):
        print("###############get data#################")
        processed_data = np.load(INPUT_DATA, allow_pickle=True)
        training_imgs = processed_data[0]
        training_labels = processed_data[1]
        validation_imgs = processed_data[2]
        validation_labels = processed_data[3]
        testing_imgs = processed_data[4]
        testing_labels = processed_data[5]
        print("%d train, %d validation, %d test" %
              (len(training_imgs), len(validation_imgs), len(testing_imgs)))
        print("###############finish get data!#################")

    # define the input/output placeholders
    patch = tf.placeholder(tf.float32,
                           [BATCH_SIZE, PATCH_SIZE, PATCH_SIZE, INCHANNEL],
                           name='input-patch')
    labels = tf.placeholder(tf.int64, [BATCH_SIZE], name='input-label')

    # build the model
    TRAIN_FLAG = True
    logits = inf.inference(patch, N_CLASSES, TRAIN_FLAG)
    global_step = tf.Variable(0, trainable=False)
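    # minimize(..., global_step=global_step) below increments this counter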

    with tf.name_scope('loss'):
        # the default on/off values make the one-hot tensor float32, which
        # softmax_cross_entropy_with_logits_v2 requires for its labels
        label_one_hot = tf.one_hot(labels, depth=N_CLASSES)
        entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=label_one_hot, logits=logits)
        loss = tf.reduce_mean(entropy, name='loss')

    with tf.name_scope('accuracy'):
        preds = tf.nn.softmax(logits)
        correct_preds = tf.equal(tf.argmax(preds, 1),
                                 tf.argmax(label_one_hot, 1))
        accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
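        # a count of correct predictions per batch, not a ratio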

    with tf.name_scope('learning_rate'):
        learning_rate = tf.train.exponential_decay(
            learning_rate=LEARNING_RATE_BASE,
            global_step=global_step,
            decay_steps=BATCH_SIZE,
            decay_rate=LEARNING_RATE_DECAY)

    with tf.name_scope('summaries'):
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.histogram('histogram loss', loss)
        summary_op = tf.summary.merge_all()

    train_op = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
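    # note: the decayed learning_rate tensor above is re-evaluated at every step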

    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(GRAPH_PATH, tf.get_default_graph())
    # `config` is not defined in this snippet; a minimal ConfigProto is assumed
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        utils.safe_mkdir('checkpoints')
        utils.safe_mkdir(MODEL_SAVE_PATH)
        sess.run(tf.global_variables_initializer())

        print("############  check the saver  ################")
        ckpt = tf.train.get_checkpoint_state(
            os.path.dirname(MODEL_SAVE_PATH + '/checkpoint'))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('#####  successfully loaded the model  #######')
        step = global_step.eval()
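        # resume the step counter from the restored global_step (0 on a fresh run)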
        print("############  start training  ################")
        train_EPOCH = len(training_imgs)
        for epoch in range(train_EPOCH):
            start_time = time.time()
            total_loss = 0
            TRAIN_FLAG = True

            start = 0
            end = BATCH_SIZE
            m_idx = 0  # running count of tampered patches in the current image
            # load data
            patch_data, label_data, m_idx = utils.get_patch(
                training_imgs[epoch], training_labels[epoch], PATCH_SIZE,
                STRIDE, IMG_SIZE, INCHANNEL, THRESHOLD, m_idx)
            N_PATCH = patch_data.shape[0]
            print('Patch num: {0} Modified Patch num: {1}'.format(
                label_data.shape, m_idx))
            train_step = int(N_PATCH / BATCH_SIZE)
            #train with each patch
            for batch_idx in range(train_step):
                _, l, summaries = sess.run([train_op, loss, summary_op],
                                           feed_dict={
                                               patch: patch_data[start:end],
                                               labels: label_data[start:end]
                                           })
                writer.add_summary(summaries, global_step=step)
                # every LOSS_OUT_STEPS steps, print the loss of the current batch
                if (step + 1) % LOSS_OUT_STEPS == 0:
                    print('Loss at step {0}: {1}'.format(step, l))
                step += 1
                total_loss += l
                start += BATCH_SIZE
                end += BATCH_SIZE

            # every SAVE_NUM epochs: report the average loss and save the model
            if (epoch + 1) % SAVE_NUM == 0:
                # loss is already a per-batch mean, so average over the
                # number of batches rather than dividing by PATCH_SIZE
                print('Average loss at epoch {0}: {1}'.format(
                    epoch, total_loss / train_step))
                print('Took: {0} seconds for one epoch'.format(time.time() -
                                                               start_time))
                saver.save(sess,
                           save_path=os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                           global_step=step)
def eval(INPUT_DATA):
    with tf.name_scope('data'):
        print("###############get data#################")
        processed_data = np.load(INPUT_DATA, allow_pickle=True)
        testing_imgs = processed_data[4]
        testing_labels = processed_data[5]
        print("###############finish get data!#################")

    with tf.Graph().as_default() as g:
        # define the input/output format
        patch = tf.placeholder(tf.float32, [
            ifd_train.BATCH_SIZE, ifd_train.PATCH_SIZE, ifd_train.PATCH_SIZE,
            ifd_train.INCHANNEL
        ],
                               name='test_patch')
        labels = tf.placeholder(tf.int64, [ifd_train.BATCH_SIZE],
                                name='test_label')

        # call the network
        logits = inf.inference(patch,
                               n_classes=ifd_train.N_CLASSES,
                               TRAIN_FLAG=False)

        with tf.name_scope('accuracy'):
            label_one_hot = tf.one_hot(labels,
                                       depth=ifd_train.N_CLASSES,
                                       on_value=1)
            preds = tf.nn.softmax(logits)
            correct_preds = tf.equal(tf.argmax(preds, 1),
                                     tf.argmax(label_one_hot, 1))
            accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))

        saver = tf.train.Saver()

        while True:
            with tf.Session() as sess:
                print("############  load the model  ################")
                ckpt = tf.train.get_checkpoint_state(ifd_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)

                    global_step = ckpt.model_checkpoint_path.split(
                        '/')[-1].split('-')[-1]

                    test_EPOCH = len(testing_imgs)
                    acc_total = 0
                    for epoch in range(test_EPOCH):
                        acc_total_epoch = 0
                        start = 0
                        end = ifd_train.BATCH_SIZE
                        # load data
                        patch_data, label_data = utils.get_patch(
                            testing_imgs[epoch], testing_labels[epoch],
                            ifd_train.PATCH_SIZE, ifd_train.STRIDE,
                            ifd_train.IMG_SIZE, ifd_train.INCHANNEL,
                            ifd_train.THRESHOLD)
                        N_PATCH = patch_data.shape[0]
                        test_step = int(N_PATCH / ifd_train.BATCH_SIZE)
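                        # leftover patches beyond test_step * BATCH_SIZE are dropped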

                        for idx in range(test_step):

                            accuracy_score = sess.run(
                                accuracy,
                                feed_dict={
                                    patch: patch_data[start:end],
                                    labels: label_data[start:end]
                                })
                            accuracy_score /= ifd_train.BATCH_SIZE
                            acc_total_epoch += accuracy_score
                            start += ifd_train.BATCH_SIZE
                            end += ifd_train.BATCH_SIZE
                        print('Average accuracy for image {0}: {1}'.format(
                            epoch, acc_total_epoch / test_step))
                        # accumulate the per-image average so the final
                        # division by test_EPOCH yields a true mean
                        acc_total += acc_total_epoch / test_step
                    print('Average Accuracy at step {0}: {1}'.format(
                        global_step, acc_total / test_EPOCH))
                else:
                    print('No checkpoint file found')
            return  # evaluate the latest checkpoint once, then exit
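
As in the first example, a minimal driver sketch; INPUT_DATA and the other module-level constants are assumed to be defined in this module:

if __name__ == '__main__':
    train()           # fit the network, checkpointing every SAVE_NUM epochs
    eval(INPUT_DATA)  # score the latest checkpoint on the test split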