def lr_schedule(epoch):
    """Step-decay learning-rate schedule.

    Starts from the base rate ``FLAGS.lr`` and decays it at the two epoch
    boundaries stored in ``epochs_inter``:

    * ``epoch > epochs_inter[1]``  -> base rate * 0.01 (checked first)
    * ``epoch > epochs_inter[0]``  -> base rate * 0.1
    * otherwise                    -> base rate unchanged

    The chosen rate is printed before it is returned.

    Args:
        epoch: Current epoch index.

    Returns:
        The learning rate to use for ``epoch``.
    """
    base_rate = FLAGS.lr
    if epoch > epochs_inter[1]:
        rate = base_rate * 1e-2
    elif epoch > epochs_inter[0]:
        rate = base_rate * 1e-1
    else:
        rate = base_rate
    print('Learning rate: ', rate)
    return rate


# The ResNet (model_input, original_model, final_features) and BN_name are
# built only once, after the TensorFlow session has been created and
# registered with Keras, so that no unused copy of the network is added to
# the graph.
# NOTE(review): building the network only once changes Keras auto-generated
# layer/variable names; confirm that no checkpoint is restored by name.

# One-hot encode the integer class labels for the training set, the test set
# and the test-set target labels.
y_train, y_test, y_test_target = (
    keras.utils.to_categorical(labels, num_class)
    for labels in (y_train, y_test, y_test_target)
)

# Input TF placeholders for one-hot label batches: y_place holds the labels,
# y_target presumably the attack target labels (confirm against later code).
y_place = tf.placeholder(dtype=tf.float32, shape=(None, num_class))
y_target = tf.placeholder(dtype=tf.float32, shape=(None, num_class))

# Create the TensorFlow session and make Keras use it for all later ops.
sess = tf.Session()
keras.backend.set_session(sess)

model_input = Input(shape=input_shape)

# Build the ResNet of the requested version. Both builders take the same
# keyword arguments; only the model (first) and final_features (fifth) of
# the returned 5-tuple are kept.
# NOTE(review): original note says the logits are batch_size x dim_means;
# not verifiable from this chunk.
_resnet_builder = resnet_v2 if version == 2 else resnet_v1
original_model, _, _, _, final_features = _resnet_builder(
    input=model_input,
    depth=depth,
    num_classes=num_class,
    use_BN=FLAGS.use_BN,
    use_dense=FLAGS.use_dense,
    use_leaky=FLAGS.use_leaky,
)

# Suffix used to tag artifacts by whether batch normalization is enabled.
# Plain truthiness (PEP 8) instead of `== True`; assumes FLAGS.use_BN is a
# boolean flag -- confirm against the flag definition.
if FLAGS.use_BN:
    BN_name = '_withBN'
    print('Use BN in the model')
else:
    BN_name = '_noBN'
    print('Do not use BN in the model')

#Whether use target attack for adversarial training
if FLAGS.use_target == False:
    is_target = ''
else: