Exemple #1
0
def create_pgd_graph(lb, ub, sess, tf_input, tf_output, target):
    """Build the ops used to run a sign-gradient (PGD-style) search.

    A trainable variable initialised to ``lb`` is substituted for
    ``tf_input`` in a copy of the sub-graph that produces ``tf_output``.
    Four in-place update ops on that variable are then created: an
    output-diversification step, a PGD step, a clip to ``[lb, ub]`` and a
    plain re-seed.

    Args:
        lb: Initial value of the image variable and lower clipping bound.
            Its ``.shape`` sizes the step-size and seed placeholders, which
            are all ``float64`` (presumably ``lb`` is a float64 numpy
            array -- confirm against the caller).
        ub: Upper clipping bound.
        sess: Not used by this function.
        tf_input: Tensor to be replaced by the image variable.
        tf_output: Tensor whose producing sub-graph is copied.
        target: Class label(s) handed to ``to_categorical`` to build the
            one-hot target of the cross-entropy loss.

    Returns:
        A 10-tuple ``(image variable, direction placeholder, seed
        placeholder, init step-size placeholder, PGD step-size placeholder,
        replaced output tensor, init assign op, PGD assign op, clip assign
        op, seed assign op)``.
    """
    # Swap the original input for a variable we can update in place.
    image_var = tf.Variable(lb, trainable=True)
    # NOTE(review): the "+ 0.0" presumably turns the variable into a plain
    # tensor before it is handed to graph_replace -- confirm.
    image_tensor = image_var + 0.0
    replaced_output = ge.graph_replace(tf_output, {tf_input: image_tensor})

    # Output diversification: step along the sign of the gradient of a
    # weighted sum of the outputs.
    direction_pl = tf.placeholder(shape=(replaced_output.shape[1]),
                                  dtype=tf.float64)
    init_step_pl = tf.placeholder(shape=lb.shape, dtype=tf.float64)
    weighted_output = tf.reduce_sum(direction_pl * replaced_output)
    init_gradient = tf.gradients(weighted_output, [image_var])[0]
    init_delta = init_step_pl * tf.sign(init_gradient)
    init_assign = tf.assign(image_var, image_var + init_delta)

    # PGD: descend the cross-entropy between the outputs and the target.
    one_hot_target = tf.keras.utils.to_categorical(
        target, num_classes=replaced_output.shape[-1])
    pgd_step_pl = tf.placeholder(shape=lb.shape, dtype=tf.float64)
    target_loss = tf.keras.losses.categorical_crossentropy(
        one_hot_target, replaced_output, from_logits=True)
    pgd_gradient = tf.gradients(target_loss, [image_var])[0]
    pgd_delta = pgd_step_pl * tf.sign(pgd_gradient)
    pgd_assign = tf.assign(image_var, image_var - pgd_delta)

    # Clip: project the variable back into the [lb, ub] box.
    clipped_image = tf.clip_by_value(image_var, lb, ub)
    clip_assign = tf.assign(image_var, clipped_image)

    # Seed: overwrite the variable with a caller-supplied value.
    seed_pl = tf.placeholder(shape=lb.shape, dtype=tf.float64)
    seed_assign = tf.assign(image_var, seed_pl)

    return (
        image_var,
        direction_pl,
        seed_pl,
        init_step_pl,
        pgd_step_pl,
        replaced_output,
        init_assign,
        pgd_assign,
        clip_assign,
        seed_assign,
    )
  def test_graph_replace_gradients(self):
    """graph_replace works on a graph that contains gradient ops.

    Builds y = w * w * w with its gradient w.r.t. w, replaces the read of
    ``w`` by the gradient output, and checks that both the original and the
    copied gradient evaluate to ~0.0 after variable initialisation.
    """
    source_graph = tf.Graph()
    with source_graph.as_default():
      w = tf.Variable(0.0, name="w")
      w_squared = tf.multiply(w, w, name="mul1")
      y = tf.multiply(w_squared, w, name="mul2")
      dy_dw = tf.gradients(y, w, name="gradient")[0]
      _ = tf.identity(dy_dw, "grad")

    g = gde.Graph(source_graph)
    grad_out = g["grad"].output(0)

    # Map the variable read onto the gradient output.
    replacement_ts = {g["w/read"].output(0): grad_out}

    # Should not raise exception.
    res = gde.graph_replace(grad_out, replacement_ts, dst_scope="res")

    self.assertNotEqual(res.name, grad_out.name)
    after_graph = tf.Graph()
    with after_graph.as_default():
      tf.import_graph_def(g.to_graph_def(), name="")
      gde.util.load_variables_to_tf_graph(g)
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        fetch_names = [grad_out.name, res.name]
        g_val, res_val = sess.run(fetch_names)
    self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
    self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
  def test_graph_replace(self):
    """Replacing ``a`` with ``a_new`` yields a copy of ``c`` with a new value.

    The original ``c`` must evaluate to ~2.001 and the copy to ~3.001.
    """
    g, a, a_new, c = self._create_replace_graph()
    replaced_c = gde.graph_replace(c, {a: a_new})

    with g.to_tf_graph().as_default(), tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      fetch_names = [c.name, replaced_c.name]
      original_val, replaced_val = sess.run(fetch_names)

    self.assertNear(original_val, 2.001, ERROR_TOLERANCE)
    self.assertNear(replaced_val, 3.001, ERROR_TOLERANCE)
 def test_graph_replace_missing(self):
   """graph_replace on a target list where one target does not use ``a``.

   With ``c = a + 2 * b`` and ``a`` replaced by ``d``, the result for ``b``
   must still be named "b:0" while the result for ``c`` is named "c_1:0".
   """
   source_graph = tf.Graph()
   with source_graph.as_default():
     const_a = tf.constant(1.0, name="a")
     const_b = tf.constant(2.0, name="b")
     doubled_b = 2 * const_b
     _ = tf.add(const_a, doubled_b, name="c")
     _ = tf.constant(2.0, name="d")
   g = gde.Graph(source_graph)
   targets = [g["b"].output(0), g["c"].output(0)]
   mapping = {g["a"].output(0): g["d"].output(0)}
   res = gde.graph_replace(targets, mapping)
   self.assertEqual(res[0].name, "b:0")
   self.assertEqual(res[1].name, "c_1:0")
  def test_graph_replace_dict(self):
    """graph_replace accepts a dict of targets and returns a dict.

    The returned dict is evaluated through ``sess.run`` as a dict of tensor
    names; the original ``c`` must be ~2.001 and the copy under key "c"
    must be ~3.001.
    """
    g, a, a_new, c = self._create_replace_graph()
    replaced = gde.graph_replace({"c": c}, {a: a_new})
    self.assertIsInstance(replaced, dict)

    with g.to_tf_graph().as_default(), tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      names_by_key = {key: tensor.name for key, tensor in replaced.items()}
      original_val, replaced_vals = sess.run([c.name, names_by_key])

    self.assertIsInstance(replaced_vals, dict)
    self.assertNear(original_val, 2.001, ERROR_TOLERANCE)
    self.assertNear(replaced_vals["c"], 3.001, ERROR_TOLERANCE)
 def test_graph_replace_named_tuple(self):
   """graph_replace on a namedtuple target returns the same namedtuple type."""
   _, a, a_new, c = self._create_replace_graph()
   one_tensor = collections.namedtuple("OneTensor", ["t"])
   wrapped_target = one_tensor(c)
   replaced = gde.graph_replace(wrapped_target, {a: a_new})
   self.assertIsInstance(replaced, one_tensor)
 def test_graph_replace_ordered_dict(self):
   """graph_replace on an OrderedDict target returns an OrderedDict."""
   _, a, a_new, c = self._create_replace_graph()
   wrapped_target = collections.OrderedDict({"c": c})
   replaced = gde.graph_replace(wrapped_target, {a: a_new})
   self.assertIsInstance(replaced, collections.OrderedDict)