def create_pgd_graph(lb, ub, sess, tf_input, tf_output, target):
    """Build TF1 graph ops for a box-constrained signed-gradient (PGD) search.

    The network input ``tf_input`` is rewired with ``ge.graph_replace`` to a
    trainable variable initialised to ``lb``. Around that variable this builds
    assign ops for an output-diversification step, a PGD step that lowers
    the cross-entropy towards ``target``, clipping into ``[lb, ub]``, and
    seeding the variable with a fed value.

    Args:
        lb: Lower bound of the input box (must have ``.shape``); also the
            initial value of the variable and the shape of every per-element
            placeholder.
        ub: Upper bound of the input box.
        sess: Unused; kept so existing callers keep working.
        tf_input: Tensor of the existing graph that is replaced by the variable.
        tf_output: Output tensor computed from ``tf_input``; treated as logits
            by the cross-entropy below.
        target: Class index turned into a one-hot target for the PGD loss.

    Returns:
        Tuple ``(tf_image, tf_dir, tf_seed_pl, tf_eps_init, tf_eps_pgd,
        tf_output, tf_train_init, tf_train_pgd, tf_train_clip, tf_seed)``:
        the variable, the placeholders to feed, the rewired output tensor,
        and the four assign ops.
    """
    # Replace the graph input with the variable. NOTE(review): the "+ 0.0"
    # presumably exists because graph_replace needs a plain tensor rather
    # than a tf.Variable — confirm before removing.
    tf_image = tf.Variable(lb, trainable=True)
    tf_output = ge.graph_replace(tf_output, {tf_input: tf_image + 0.0})

    # Use the graph's own dtypes instead of a hard-coded float64. For float64
    # inputs this is identical; other dtypes previously failed on mismatch.
    image_dtype = tf_image.dtype.base_dtype

    # Output diversification: one signed-gradient step that increases
    # sum(tf_dir * tf_output).
    # NOTE(review): shape[1] assumes tf_output is (batch, classes) — confirm.
    tf_dir = tf.placeholder(shape=(tf_output.shape[1],), dtype=tf_output.dtype)
    tf_eps_init = tf.placeholder(shape=lb.shape, dtype=image_dtype)
    tf_init_error = tf.reduce_sum(tf_dir * tf_output)
    tf_init_grad = tf.gradients(tf_init_error, [tf_image])[0]
    tf_train_init = tf_image + tf_eps_init * tf.sign(tf_init_grad)
    tf_train_init = tf.assign(tf_image, tf_train_init)

    # PGD: a signed-gradient step that decreases the cross-entropy between
    # the logits and the one-hot target (a numpy array, not a graph op).
    tf_target_onehot = tf.keras.utils.to_categorical(
        target, num_classes=tf_output.shape[-1])
    tf_eps_pgd = tf.placeholder(shape=lb.shape, dtype=image_dtype)
    tf_train_error = tf.keras.losses.categorical_crossentropy(
        tf_target_onehot, tf_output, from_logits=True)
    tf_train_grad = tf.gradients(tf_train_error, [tf_image])[0]
    tf_train_pgd = tf_image - tf_eps_pgd * tf.sign(tf_train_grad)
    tf_train_pgd = tf.assign(tf_image, tf_train_pgd)

    # Clip: project the variable back into the box [lb, ub].
    tf_train_clip = tf.clip_by_value(tf_image, lb, ub)
    tf_train_clip = tf.assign(tf_image, tf_train_clip)

    # Seed: overwrite the variable with an externally fed point.
    tf_seed_pl = tf.placeholder(shape=lb.shape, dtype=image_dtype)
    tf_seed = tf.assign(tf_image, tf_seed_pl)

    return (tf_image, tf_dir, tf_seed_pl, tf_eps_init, tf_eps_pgd, tf_output,
            tf_train_init, tf_train_pgd, tf_train_clip, tf_seed)
def test_graph_replace_gradients(self):
    """graph_replace must handle rewiring a variable read onto its own gradient.

    Builds y = w * w * w with w = 0, takes dy/dw, then replaces the read of
    ``w`` with that gradient. Both the original and the rewired gradient must
    evaluate to ~0.
    """
    source_graph = tf.Graph()
    with source_graph.as_default():
        weight = tf.Variable(0.0, name="w")
        squared = tf.multiply(weight, weight, name="mul1")
        cubed = tf.multiply(squared, weight, name="mul2")
        gradient_list = tf.gradients(cubed, weight, name="gradient")
        tf.identity(gradient_list[0], "grad")

    editable = gde.Graph(source_graph)
    grad_out = editable["grad"].output(0)
    # Point the variable's read tensor at the gradient output.
    mapping = {editable["w/read"].output(0): grad_out}
    # This call must not raise.
    rewired = gde.graph_replace(grad_out, mapping, dst_scope="res")
    self.assertNotEqual(rewired.name, grad_out.name)

    runtime_graph = tf.Graph()
    with runtime_graph.as_default():
        tf.import_graph_def(editable.to_graph_def(), name="")
        gde.util.load_variables_to_tf_graph(editable)
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            fetched = session.run([grad_out.name, rewired.name])
    original_value, rewired_value = fetched
    self.assertNear(original_value, 0.0, ERROR_TOLERANCE)
    self.assertNear(rewired_value, 0.0, ERROR_TOLERANCE)
def test_graph_replace(self):
    """Replacing ``a`` with ``a_new`` yields a copy of ``c`` that evaluates differently."""
    graph, a, a_new, c = self._create_replace_graph()
    replaced_c = gde.graph_replace(c, {a: a_new})
    with graph.to_tf_graph().as_default(), tf.Session() as session:
        session.run(tf.global_variables_initializer())
        original_value, replaced_value = session.run([c.name, replaced_c.name])
    self.assertNear(original_value, 2.001, ERROR_TOLERANCE)
    self.assertNear(replaced_value, 3.001, ERROR_TOLERANCE)
def test_graph_replace_missing(self):
    """Outputs not reachable from the replaced tensor keep their names; dependents get copies."""
    source_graph = tf.Graph()
    with source_graph.as_default():
        const_a = tf.constant(1.0, name="a")
        const_b = tf.constant(2.0, name="b")
        tf.add(const_a, 2 * const_b, name="c")
        tf.constant(2.0, name="d")

    editable = gde.Graph(source_graph)
    requested = [editable[op_name].output(0) for op_name in ("b", "c")]
    swap = {editable["a"].output(0): editable["d"].output(0)}
    b_result, c_result = gde.graph_replace(requested, swap)
    # "b" does not depend on "a", so it is returned unchanged.
    self.assertEqual(b_result.name, "b:0")
    # "c" depends on "a", so a renamed copy is produced.
    self.assertEqual(c_result.name, "c_1:0")
def test_graph_replace_dict(self):
    """graph_replace keeps a dict input as a dict, with each value replaced."""
    g, a, a_new, c = self._create_replace_graph()
    c_new = gde.graph_replace({"c": c}, {a: a_new})
    # assertIsInstance reports the actual type on failure, unlike assertTrue.
    self.assertIsInstance(c_new, dict)
    self.assertEqual(set(c_new), {"c"})
    with g.to_tf_graph().as_default():
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            c_val, c_new_val = sess.run(
                [c.name, {k: v.name for k, v in c_new.items()}])
    self.assertIsInstance(c_new_val, dict)
    self.assertNear(c_val, 2.001, ERROR_TOLERANCE)
    self.assertNear(c_new_val["c"], 3.001, ERROR_TOLERANCE)
def test_graph_replace_named_tuple(self):
    """graph_replace keeps the namedtuple type of its input and replaces its field."""
    _, a, a_new, c = self._create_replace_graph()
    one_tensor = collections.namedtuple("OneTensor", ["t"])
    c_new = gde.graph_replace(one_tensor(c), {a: a_new})
    # assertIsInstance reports the actual type on failure, unlike assertTrue.
    self.assertIsInstance(c_new, one_tensor)
    # c depends on a, so the field must now hold a distinct, replaced copy.
    self.assertNotEqual(c_new.t.name, c.name)
def test_graph_replace_ordered_dict(self):
    """graph_replace keeps an OrderedDict input as an OrderedDict with replaced values."""
    _, a, a_new, c = self._create_replace_graph()
    c_new = gde.graph_replace(collections.OrderedDict({"c": c}), {a: a_new})
    # assertIsInstance reports the actual type on failure, unlike assertTrue.
    self.assertIsInstance(c_new, collections.OrderedDict)
    self.assertEqual(list(c_new), ["c"])
    # c depends on a, so the value must be a distinct, replaced copy.
    self.assertNotEqual(c_new["c"].name, c.name)