def run_eval(val_loop=2,
             number_of_classes=5,
             batch_size=40,
             sigma=0.6,
             checkpoint_dir=None,
             dataset_dir=None):
    """Evaluate the classification and adversarial branches on validation data.

    Builds the evaluation graph (``forward_tran`` -> ``batch_mask_images`` ->
    ``forward_tran_advers``), restores the latest checkpoint found in
    ``checkpoint_dir`` when one exists, and evaluates ``val_loop`` batches
    drawn from ``gen_data_batch(..., Train=False)``.

    Args:
        val_loop: Number of validation batches to evaluate; must be >= 1.
        number_of_classes: Number of output classes.
        batch_size: Batch size; baked into the placeholder shapes.
        sigma: Forwarded to ``batch_mask_images``.
        checkpoint_dir: Directory passed to ``tf.train.get_checkpoint_state``.
            If no checkpoint is found, the freshly initialized weights are
            evaluated (a message is printed).
        dataset_dir: Forwarded to ``gen_data_batch``.

    Returns:
        The mean of the per-batch classifier accuracy (accuracy_A) and
        adversarial-branch accuracy (accuracy_B), pooled together over all
        evaluated batches.

    Raises:
        ValueError: If ``val_loop`` is smaller than 1 (there would be nothing
            to average).
    """
    if val_loop < 1:
        raise ValueError('val_loop must be >= 1, got %r' % (val_loop,))

    tf.reset_default_graph()
    with tf.Graph().as_default():
        # Validation input pipeline.
        img_val, label_val = gen_data_batch(dataset_dir=dataset_dir,
                                            batch_size=batch_size,
                                            Train=False)

        # Batches are fed through placeholders; the shapes fix the batch size.
        x_img_in = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
        x_label = tf.placeholder(tf.int64, [batch_size])
        label_hot = tf.one_hot(x_label, number_of_classes)

        # Adversarial target: label k is mapped to (number_of_classes - 1 - k).
        label_advers = number_of_classes - 1 - x_label
        label_advers_hot = tf.one_hot(label_advers, number_of_classes)

        layer_name1 = 'resnet_v2_152/block4'
        (target_conv_layer, target_grad_ac, target_grad_yc,
         logits_cl, end_points) = forward_tran(
            x_img=x_img_in,
            number_of_classes=number_of_classes,
            layer_name=layer_name1,
            Training=False)
        conv_in = end_points[layer_name1]
        mask_conv_imgs, _ = batch_mask_images(
            batch_size=batch_size,
            img_conv=conv_in,
            target_conv_layer=target_conv_layer,
            target_grad_ac=target_grad_ac,
            target_grad_yc=target_grad_yc,
            sigma=sigma)

        _, _, _, logits_adver, _ = forward_tran_advers(
            x_img=mask_conv_imgs,
            number_of_classes=number_of_classes,
            Training=False)

        accuracy = evaluation(logits_cl, label_hot)
        accuracy_adver = evaluation(logits_adver, label_advers_hot)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # grow GPU memory on demand

        start_time = time.time()
        accuracy_list = []
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                saver = tf.train.Saver()
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print('checkpoint restored from [{0}]'.format(
                        ckpt.model_checkpoint_path))
                    print('model_checkpoint_path Loading success')
                else:
                    print('No checkpoint before training')

                for i in range(val_loop):
                    # Pull exactly one batch from the input pipeline, then
                    # fetch only placeholder-driven ops. Fetching img_val or
                    # label_val again here would dequeue (and skip) another
                    # batch.
                    _img_val, _label_val = sess.run([img_val, label_val])
                    feed_vars = {x_img_in: _img_val, x_label: _label_val}
                    _accuracy, _accuracy_adver = sess.run(
                        [accuracy, accuracy_adver], feed_dict=feed_vars)
                    accuracy_list.append(_accuracy)
                    accuracy_list.append(_accuracy_adver)
                    print(
                        'step:%d, test, (accuracy_A, accuracy_B) = (%.3f, %.3f)' %
                        (i, _accuracy, _accuracy_adver))
            finally:
                # Always stop the queue-runner threads, even on error.
                coord.request_stop()
                coord.join(threads)

    aver_accuracy = sum(accuracy_list) / len(accuracy_list)
    print('At last, test, aver_accuracy:%.4f' % (aver_accuracy))
    print('time use is %d second' % (time.time() - start_time))
    return aver_accuracy
def run_training(number_of_classes=5,
                 batch_size=10,
                 learning_rate=0.00001,
                 num_train_img=90,
                 num_epoc=200,
                 hide_prob=0.1,
                 sigma=0.6,
                 dataset_dir=None,
                 logs_train_dir=None,
                 checkpoint_dir=None,
                 checkpoint_exclude_scopes=None):
    """Jointly train the classification branch and the adversarial branch.

    Builds the training graph (``forward_tran`` -> ``batch_mask_images`` ->
    ``forward_tran_advers`` -> ``losses`` -> ``trainning``), restores weights
    and runs ``num_epoc * num_train_img // batch_size`` optimization steps.

    Weight restoration, in order of preference:
      1. the latest checkpoint in ``logs_train_dir`` (resume training);
      2. ``checkpoint_dir`` via ``slim.assign_from_checkpoint_fn``, restoring
         only model variables whose names do not start with one of the
         comma-separated ``checkpoint_exclude_scopes``;
      3. otherwise the freshly initialized weights are used.
    NOTE(review): the step loop always starts at 0, so resuming from (1) does
    not shorten training.

    Every 100 steps the losses/accuracy are printed and the merged summaries
    written to ``logs_train_dir``. A checkpoint
    ``<logs_train_dir>/model_Resnet152.ckpt`` is saved every 1200 steps and
    at the last step.

    Args:
        number_of_classes: Number of output classes.
        batch_size: Batch size; baked into the placeholder shapes.
        learning_rate: Forwarded to ``trainning``.
        num_train_img: Number of training images (defines the epoch length).
        num_epoc: Number of epochs.
        hide_prob: Forwarded to ``patch_epock_img``.
        sigma: Forwarded to ``batch_mask_images``.
        dataset_dir: Forwarded to ``gen_data_batch``.
        logs_train_dir: Directory for summaries and checkpoints (created if
            missing). Required.
        checkpoint_dir: Passed as ``model_path`` to
            ``slim.assign_from_checkpoint_fn``. That function expects a
            checkpoint path rather than a directory; TODO confirm against the
            callers.
        checkpoint_exclude_scopes: Comma-separated variable-name prefixes to
            skip when restoring from ``checkpoint_dir``; empty entries are
            ignored.

    Raises:
        ValueError: If ``logs_train_dir`` is None.
    """
    if logs_train_dir is None:
        raise ValueError('logs_train_dir is required')
    if not os.path.exists(logs_train_dir):
        os.makedirs(logs_train_dir)

    tf.reset_default_graph()
    with tf.Graph().as_default():
        # Training input pipeline.
        img_train, label_train = gen_data_batch(dataset_dir=dataset_dir,
                                                batch_size=batch_size,
                                                Train=True)

        # Batches are fed through placeholders; the shapes fix the batch size.
        x_img_in = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
        x_label = tf.placeholder(tf.int64, [batch_size])
        label_hot = tf.one_hot(x_label, number_of_classes)

        # Adversarial target: label k is mapped to (number_of_classes - 1 - k).
        label_advers = number_of_classes - 1 - x_label
        label_advers_hot = tf.one_hot(label_advers, number_of_classes)

        layer_name1 = 'resnet_v2_152/block4'
        (target_conv_layer, target_grad_ac, target_grad_yc,
         logits_cl, end_points) = forward_tran(
            x_img=x_img_in,
            label_index=x_label,
            number_of_classes=number_of_classes,
            layer_name=layer_name1,
            Training=True)
        conv_in = end_points[layer_name1]
        mask_conv_imgs, _ = batch_mask_images(
            batch_size=batch_size,
            img_input=None,
            img_conv=conv_in,
            target_conv_layer=target_conv_layer,
            target_grad_ac=target_grad_ac,
            target_grad_yc=target_grad_yc,
            sigma=sigma)

        _, _, _, logits_adver, _ = forward_tran_advers(
            x_img=mask_conv_imgs,
            label_index=label_advers,
            number_of_classes=number_of_classes,
            Training=True)

        total_loss, logit_cl_loss, logit_adver_loss, L2_loss = losses(
            logits_cl,
            logits_adver,
            label_hot,
            label_advers_hot,
            number_of_classes)

        train_op, global_step = trainning(total_loss, learning_rate)
        # Reported accuracy is the mean of both branches' accuracies.
        train_accuracy = (evaluation(logits_cl, label_hot) +
                          evaluation(logits_adver, label_advers_hot)) / 2.0

        summary_op = tf.summary.merge_all()
        if summary_op is None:
            # merge_all() returns None when no summaries exist; None cannot be
            # fetched, whereas running a no-op simply yields None.
            summary_op = tf.no_op()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # grow GPU memory on demand

        # Number of whole batches per epoch, handed to patch_epock_img.
        one_epock_step = num_train_img // batch_size
        # Use the real loop bound so the final-step checkpoint is always
        # written. num_epoc * one_epock_step is smaller than this whenever
        # batch_size does not divide num_train_img.
        MAX_STEP = num_epoc * num_train_img // batch_size
        checkpoint_path = os.path.join(logs_train_dir, 'model_Resnet152.ckpt')

        with tf.Session(config=config) as sess:
            train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
            start_time = time.time()
            coord = tf.train.Coordinator()
            threads = []
            try:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                saver = tf.train.Saver()
                print('*******run_trainging11*******')
                ckpt = tf.train.get_checkpoint_state(logs_train_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print('checkpoint restored from [{0}]'.format(
                        ckpt.model_checkpoint_path))
                    print('model_checkpoint_path Loading success')
                elif checkpoint_dir:
                    exclusions = ()
                    if checkpoint_exclude_scopes:
                        # Drop empty entries: '' prefix-matches every name and
                        # would exclude the whole model (e.g. "a,b,").
                        exclusions = tuple(
                            scope.strip()
                            for scope in checkpoint_exclude_scopes.split(',')
                            if scope.strip())
                    # str.startswith(()) is False, so no exclusions keeps all.
                    variables_to_restore = [
                        var for var in slim.get_model_variables()
                        if not var.op.name.startswith(exclusions)
                    ]
                    read_weights_restore = slim.assign_from_checkpoint_fn(
                        checkpoint_dir, variables_to_restore)
                    read_weights_restore(sess)
                    print('checkpoint restored from [{0}]'.format(
                        checkpoint_dir))
                else:
                    print('No checkpoint before training')
                print('******run_training12******')

                for step in range(MAX_STEP):
                    _img_train, _label_train = sess.run(
                        [img_train, label_train])
                    patch_img = patch_epock_img(img=_img_train,
                                                step=step,
                                                epock_step=one_epock_step,
                                                hide_prob=hide_prob)

                    feed_vars = {x_img_in: patch_img, x_label: _label_train}
                    (_global_step, _, _total_loss, _logit_cl_loss,
                     _logit_adver_loss, _L2_loss, _accuracy,
                     summary_str) = sess.run(
                        [global_step, train_op, total_loss, logit_cl_loss,
                         logit_adver_loss, L2_loss, train_accuracy,
                         summary_op],
                        feed_dict=feed_vars)

                    # Log and write summaries every 100 steps.
                    if step % 100 == 0:
                        print('Step %d, total_loss = %.2f, cl_loss=%.2f, adver_loss=%.2f, L2_loss=%.2f, accuracy = %.3f'
                              % (_global_step, _total_loss, _logit_cl_loss,
                                 _logit_adver_loss, _L2_loss, _accuracy))
                        if summary_str is not None:
                            train_writer.add_summary(summary_str, _global_step)

                    # Checkpoint every 1200 steps and at the final step.
                    if (step % 1200 == 0 and step > 0) or step + 1 == MAX_STEP:
                        saver.save(sess, checkpoint_path)
                        # step + 1 steps have completed at this point.
                        print('%.2f sec/step' % ((time.time() - start_time) /
                                                 (step + 1)))
            finally:
                # Flush pending summaries, then stop the input threads, even
                # if training raised.
                train_writer.close()
                coord.request_stop()
                coord.join(threads)

    print('time use is %d second' % (time.time() - start_time))