def main(argv=None):
    """Compute per-class and overall top-1 accuracy of the Inception-v1 model.

    Usage: ``python compute_acc.py {train|test|adv}``.  The split name picks
    the input directory (``FLAGS.train_data_dir``, ``FLAGS.test_data_dir`` or
    ``FLAGS.adv_data_dir``) and the report path under ``./result/``;
    ``FLAGS.output_file`` is overwritten with that report path.

    Args:
        argv: Command-line arguments; ``argv[1]`` names the split.  When
            ``None``, ``sys.argv`` is used.

    Returns:
        None.  A per-class report plus the total accuracy is written to
        ``FLAGS.output_file``; the total is also printed.  Missing or unknown
        split names print a message and return without building the graph.
    """
    import os
    import sys

    print("This script is used to compute accuracy!")
    if argv is None:
        argv = sys.argv

    # split name -> (FLAGS attribute holding the input dir, report path)
    modes = {
        'test': ('test_data_dir', './result/test_accuracy.txt'),
        'train': ('train_data_dir', './result/train_accuracy.txt'),
        'adv': ('adv_data_dir', './result/adversarial_accuracy.txt'),
    }
    if len(argv) < 2:
        print("No argv! You need to assign train/test/adv data as below:")
        print("Try: python compute_acc.py train, python compute_acc.py test, "
              "python compute_acc.py adv\n")
        return
    if argv[1] not in modes:
        # Validate before building the graph / restoring the model so a typo
        # fails fast instead of with an UnboundLocalError much later.
        print("Unknown data split %r; expected one of: train, test, adv"
              % argv[1])
        return

    data_dir_flag, output_file = modes[argv[1]]
    input_dir = getattr(FLAGS, data_dir_flag)
    FLAGS.output_file = output_file

    batch_shape = [FLAGS.batch_size, FLAGS.image_height, FLAGS.image_width, 3]
    nb_classes = FLAGS.num_classes

    tf.logging.set_verbosity(tf.logging.INFO)
    config = tf.ConfigProto()
    # Grow GPU memory on demand, capped at 50% of the device.
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5

    with tf.Graph().as_default():
        print("Prepare graph...")
        x_input = tf.placeholder(tf.float32, shape=batch_shape)
        # NOTE(review): stddev is 0.0, so this presumably adds no noise —
        # confirm against gaussian_noise_layer.
        noise_input = gaussian_noise_layer(x_input, .0)

        with slim.arg_scope(inception.inception_v1_arg_scope()):
            _, end_points = inception.inception_v1(noise_input,
                                                   num_classes=nb_classes,
                                                   is_training=False)
        predicted_labels = tf.argmax(end_points['Predictions'], 1)

        print("Restore Model...")
        saver = tf.train.Saver(slim.get_model_variables())
        session_creator = tf.train.ChiefSessionCreator(
            config=config,
            scaffold=tf.train.Scaffold(saver=saver),
            checkpoint_filename_with_path=FLAGS.checkpoint_path)

        print("Run computation...")
        with tf.train.MonitoredSession(
                session_creator=session_creator) as sess:
            data_generator = load_path_label(input_dir,
                                             batch_shape,
                                             shuffle=False)

            # class id -> [number correct, number seen]
            acc_dict = {c: [0, 0] for c in range(nb_classes)}

            for images, true_labels, _ in tqdm(data_generator):
                labels = sess.run(predicted_labels,
                                  feed_dict={x_input: images})
                for pred, truth in zip(labels, true_labels):
                    acc_dict[truth][1] += 1
                    if pred == truth:
                        acc_dict[truth][0] += 1

            print("Compute accuracy...")
            output_dir = os.path.dirname(FLAGS.output_file)
            if output_dir:
                # open() below fails if ./result does not exist yet.
                tf.gfile.MakeDirs(output_dir)
            with open(FLAGS.output_file, 'w') as f:
                total_true = 0
                total_count = 0
                for i in range(nb_classes):
                    correct, seen = acc_dict[i]
                    total_true += correct
                    total_count += seen
                    # float() keeps true division under Python 2 as well.
                    accuracy = float(correct) / seen if seen else 0.0
                    f.write("class: %d, accuracy: %d/%d = %.3f \n" %
                            (i, correct, seen, accuracy))

                # Guard against an empty dataset (no samples at all).
                total_acc = (float(total_true) / total_count
                             if total_count else 0.0)
                print("Total accuracy: %.3f \n" % total_acc)
                f.write("Total accuracy: %.3f \n" % total_acc)

            print('Save accuracy result to %s' % FLAGS.output_file)
def main(_):
    """Fine-tune Inception-v1 on the training list and checkpoint each epoch.

    Batches of image paths / one-hot labels come from
    ``load_path_label('./datasets/train_labels.txt', ...)``.  The model is
    trained with Adam on the mean softmax cross-entropy.  After every epoch
    the variables collected by ``slim.get_variables_to_restore()`` are saved
    to ``<FLAGS.checkpoint_path>/robust_model-<epoch>``.  Summaries are
    written to ``FLAGS.checkpoint_path`` as well.

    When ``FLAGS.restore`` is false, an existing ``FLAGS.checkpoint_path`` is
    deleted and recreated.  Optionally, initial weights are loaded from
    ``FLAGS.pretrained_model_path``.  When ``FLAGS.restore`` is true,
    training resumes from the latest checkpoint in that directory.

    If the loss becomes NaN, training stops without writing another
    checkpoint.

    Args:
        _: Unused; presumably the argv list handed over by ``tf.app.run``.

    Raises:
        ValueError: ``FLAGS.restore`` is set but no checkpoint is found.
    """
    import os

    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        # MakeDirs (unlike MkDir) also creates missing parent directories.
        tf.gfile.MakeDirs(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        # Fresh run: wipe stale checkpoints and summaries.
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
        tf.gfile.MakeDirs(FLAGS.checkpoint_path)

    batch_shape = [FLAGS.batch_size, FLAGS.image_height, FLAGS.image_width, 3]
    nb_classes = FLAGS.num_classes
    input_images = tf.placeholder(
        tf.float32, [None, FLAGS.image_height, FLAGS.image_width, 3])
    input_labels = tf.placeholder(tf.float32, [None, nb_classes])

    learning_rate = FLAGS.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)

    with slim.arg_scope(inception.inception_v1_arg_scope()):
        # Use the configured class count so the logits match input_labels
        # (this was hard-coded to 110).
        logits, end_points = inception.inception_v1(input_images,
                                                    num_classes=nb_classes,
                                                    is_training=True)

    # Collected before the optimizer exists, so the saver stores model
    # variables only (no Adam slot variables).
    variables_to_restore = slim.get_variables_to_restore()

    loss_op = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                      labels=input_labels)
    total_loss_op = tf.reduce_mean(loss_op)
    # Batch-norm layers register their moving-average updates in UPDATE_OPS.
    # They must run with each training step; otherwise inference
    # (is_training=False) uses stale statistics.  This is a no-op if the
    # collection is empty.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Optimize the same (mean) loss that is reported.
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(
            total_loss_op)
    tf.summary.scalar('total_loss', total_loss_op)

    summary_op = tf.summary.merge_all()

    saver = tf.train.Saver(variables_to_restore)
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())

    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        # NOTE(review): only trainable variables are loaded, so batch-norm
        # moving statistics start from their initializers; consider
        # slim.get_model_variables() if the checkpoint provides them.
        variable_restore_op = slim.assign_from_checkpoint_fn(
            FLAGS.pretrained_model_path,
            slim.get_trainable_variables(),
            ignore_missing_vars=True)

    try:
        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            if FLAGS.restore:
                print('continue training from previous checkpoint')
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
                if ckpt is None:
                    raise ValueError('No checkpoint found in %s to restore from'
                                     % FLAGS.checkpoint_path)
                saver.restore(sess, ckpt)
            elif FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

            diverged = False
            for epoch in range(FLAGS.max_epochs):
                data_generator = load_path_label('./datasets/train_labels.txt',
                                                 batch_shape,
                                                 onehot=True)
                last_report = time.time()
                steps_since_report = 0
                for step in range(FLAGS.max_steps):
                    data = next(data_generator)
                    _, total_loss, res = sess.run(
                        [train_op, total_loss_op, summary_op],
                        feed_dict={
                            input_images: data[0],
                            input_labels: data[1]
                        })
                    steps_since_report += 1
                    if step % 50 == 0:
                        # Offset by epoch so later epochs do not overwrite
                        # earlier summaries.
                        summary_writer.add_summary(
                            res, epoch * FLAGS.max_steps + step)

                    if np.isnan(total_loss):
                        print('Loss diverged, stop training')
                        diverged = True
                        break

                    if step % 10 == 0:
                        # Average over the steps actually run since the last
                        # report (1 for step 0, otherwise 10).
                        elapsed = time.time() - last_report
                        avg_time_per_step = elapsed / steps_since_report
                        avg_examples_per_second = (
                            steps_since_report * FLAGS.batch_size / elapsed)
                        last_report = time.time()
                        steps_since_report = 0
                        print(
                            'Step {:06d}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                            .format(step, total_loss, avg_time_per_step,
                                    avg_examples_per_second))

                if diverged:
                    # Stop all training and do not checkpoint NaN weights.
                    break
                saver.save(sess,
                           os.path.join(FLAGS.checkpoint_path, 'robust_model'),
                           global_step=epoch)
    finally:
        summary_writer.close()
            tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic')
    ac_saver = tf.train.Saver(ac_var_list, max_to_keep=3)
    if not tf.gfile.Exists(FLAGS.ddpg_checkpoint_path):
        tf.gfile.MkDir(FLAGS.ddpg_checkpoint_path)
    else:
        if tf.train.latest_checkpoint(FLAGS.ddpg_checkpoint_path):
            ac_saver.restore(
                sess, tf.train.latest_checkpoint(FLAGS.ddpg_checkpoint_path))

    # initialization classifier
    classifier = Classifier([None, 224, 224, 3], FLAGS.num_classes)

    M = Memory(FLAGS.MEMORY_CAPACITY)
    var = 0.001  # control exploration
    start = time.time()
    data_generator = load_path_label(
        FLAGS.input_dir, [1, FLAGS.image_height, FLAGS.image_width, 3])
    for episode in range(FLAGS.max_ep_steps):
        # ep_reward = 0.0
        step = 0
        done = False
        (images, _, filepaths) = next(data_generator)
        features, labels = classifier.extract_feature(images)
        noise_images = images
        while not done:
            actions = actor.choose_action(features)
            actions = np.clip(
                np.random.normal(actions, var), -FLAGS.EPSILON, FLAGS.EPSILON
            )  # add randomness to action selection for exploration
            noise_images = np.clip(noise_images + actions, -1, 1)
            r, l2_dist, pre_labels = classifier.get_reward(
                images, noise_images, labels)