model_dic_baseline[str(i)] = resnet_v1(input=model_input_baseline,
                                           depth=depth,
                                           num_classes=num_classes,
                                           dataset=FLAGS.dataset)
    model_out_baseline.append(model_dic_baseline[str(i)][2])

# Stack every baseline member's output side by side. model_baseline is the
# multi-member model that model_baseline.load_weights(filepath_baseline)
# restores.
model_output_baseline = keras.layers.concatenate(model_out_baseline)

model_baseline = Model(inputs=model_input_baseline,
                       outputs=model_output_baseline)
# Average the member outputs to form the baseline ensemble prediction.
# Use the Keras 2 `inputs`/`outputs` keywords (as model_baseline above does);
# the legacy `input`/`output` spelling is not accepted by tf.keras.
model_ensemble_baseline = keras.layers.Average()(model_out_baseline)
model_ensemble_baseline = Model(inputs=model_input_baseline,
                                outputs=model_ensemble_baseline)

# Wrap both averaged ensembles (main and baseline) in KerasModelWrapper for
# use by the attack/evaluation code.
wrap_ensemble = KerasModelWrapper(model_ensemble)
wrap_ensemble_baseline = KerasModelWrapper(model_ensemble_baseline)

# Load trained weights into the concatenated multi-member models.
# NOTE(review): assumes model_ensemble / model_ensemble_baseline were built
# from the same member outputs as model / model_baseline, so they share
# layers and pick up these weights too — confirm.
model.load_weights(filepath)
model_baseline.load_weights(filepath_baseline)

# Evaluation settings passed to model_eval (batch size 100).
# Consider the attack to be constant.
# NOTE(review): the intent of the remark above is unclear here — confirm.
eval_par = {'batch_size': 100}

# Ensemble predictions for input tensor x (calling each Keras Model on x).
preds = model_ensemble(x)
preds_baseline = model_ensemble_baseline(x)
# Accuracy of the main ensemble's predictions on (x_test, y_test).
acc = model_eval(sess, x, y, preds, x_test, y_test, args=eval_par)
acc_baseline = model_eval(sess,
                          x,
                          y,
print('Restore model checkpoints from %s' % filepath)

# Create the ensemble: FLAGS.num_models ResNet members sharing one input.
model_input = Input(shape=input_shape)
model_dic = {}
model_out = []
for i in range(FLAGS.num_models):
    model_dic[str(i)] = resnet_v1(input=model_input,
                                  depth=depth,
                                  num_classes=num_classes,
                                  dataset=FLAGS.dataset)
    # Keep each member's output (index 2 of resnet_v1's return value).
    model_out.append(model_dic[str(i)][2])
# `model` concatenates all member outputs; its weights are loaded below.
model_output = keras.layers.concatenate(model_out)
model = Model(inputs=model_input, outputs=model_output)
# `model_ensemble` averages the same member outputs, so it shares layers with
# `model` and picks up the weights loaded into it.
model_ensemble = keras.layers.Average()(model_out)
model_ensemble = Model(inputs=model_input, outputs=model_ensemble)

# Wrap the averaged ensemble for the attack/evaluation code.
wrap_ensemble = KerasModelWrapper(model_ensemble)
eval_par = {'batch_size': 1}

# Initialize every graph variable, then overwrite the model's variables with
# the trained weights. tf.initialize_all_variables() is deprecated;
# tf.global_variables_initializer() is its documented replacement.
sess.run(tf.global_variables_initializer())
model.load_weights(filepath)

print("start attack")
num_samples = 100
preds = wrap_ensemble.get_probs(x)

print('Normal acc is:')
print(
    model_eval(
        sess,
        x,
        y,

# Build the ensemble graph: FLAGS.num_models ResNet members on one input.
model_input = Input(shape=input_shape)
model_dic = {}
model_out = []
for i in range(FLAGS.num_models):
    model_dic[str(i)] = resnet_v1(input=model_input,
                                  depth=depth,
                                  num_classes=num_classes,
                                  dataset=FLAGS.dataset)
    # Keep each member's output (index 2 of resnet_v1's return value).
    model_out.append(model_dic[str(i)][2])
# Concatenated outputs of all members.
model_output = keras.layers.concatenate(model_out)
# Use the Keras 2 `inputs`/`outputs` keywords, consistent with model_ensemble
# below; the legacy `input`/`output` spelling is not accepted by tf.keras.
model = Model(inputs=model_input, outputs=model_output)
# Average the member outputs to form the ensemble prediction.
model_ensemble = keras.layers.Average()(model_out)
model_ensemble = Model(inputs=model_input, outputs=model_ensemble)
# Wrap the averaged ensemble for the attack code, passing the class count.
wrap_ensemble = KerasModelWrapper(model_ensemble, num_class=num_classes)

# Attack budget: a scalar tensor drawn uniformly from [0.01, 0.05).
# NOTE(review): as a graph op this is presumably re-sampled on each session
# run — confirm that is intended.
eps = tf.random_uniform(shape=(), minval=0.01, maxval=0.05)
if FLAGS.attack_method == 'MadryEtAl':
    att = attacks.MadryEtAl(wrap_ensemble)
    att_params = {
        'eps': eps,
        'eps_iter': eps / 10.,
        'clip_min': clip_min,
        'clip_max': clip_max,
        'nb_iter': 10
    }
elif FLAGS.attack_method == 'MomentumIterativeMethod':
    att = attacks.MomentumIterativeMethod(wrap_ensemble)
    att_params = {
        'eps': eps,
    model_input = Input(shape=input_shape)
    model_dic = {}
    model_out = []
    for i in range(FLAGS.num_models):
        # resnet_v1 return: model, inputs, outputs, logits, final_features
        model_i_output = get_model(inputs=model_input,
                                   model=FLAGS.model,
                                   dataset=FLAGS.dataset)
        model_out.append(model_i_output)
        model_dic[i] = Model(inputs=model_input, outputs=model_i_output)

    model_output = keras.layers.concatenate(model_out)
    model = Model(inputs=model_input, outputs=model_output)
    model_ensemble = keras.layers.Average()(model_out)
    model_ensemble = Model(inputs=model_input, outputs=model_ensemble)
    wrap_ensemble = KerasModelWrapper(model_ensemble, num_class=num_classes)
else:
    assert (FLAGS.defense == 'adv')
    model_input = Input(shape=input_shape)
    model_output = get_model(inputs=model_input,
                             model=FLAGS.model,
                             dataset=FLAGS.dataset)
    model = Model(inputs=model_input, outputs=model_output)
    wrap_ensemble = KerasModelWrapper(model, num_class=num_classes)
""" Generate adversarial examples
    Args: 
        eps 

    Return: 
        adv_x: adv of model_ensemble 
        adv_x1: adv of model_1