def testFCBackwardBounds(self):
    m = snt.Linear(1,
                   initializers={
                       'w': tf.constant_initializer(1.),
                       'b': tf.constant_initializer(2.),
                   })
    z = tf.constant([[1, 2, 3]], dtype=tf.float32)
    m(z)  # Connect to create weights.
    m = ibp.LinearFCWrapper(m)
    input_bounds = ibp.IntervalBounds(z - 1., z + 1.)
    m.propagate_bounds(input_bounds)  # Create IBP bounds.
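    # Seed CROWN-style backward propagation with an identity specification;
    # `_generate_identity_spec` is a helper assumed to be defined elsewhere in
    # this test file.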
    crown_init_bounds = _generate_identity_spec([m], shape=(1, 1, 1))
    output_bounds = m.propagate_bounds(crown_init_bounds)
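    # Concretizing evaluates the symbolic linear bounds over the input
    # interval, yielding numeric lower/upper bounds.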
    concrete_bounds = output_bounds.concretize()
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        lw, uw, lb, ub, cl, cu = sess.run([
            output_bounds.lower.w, output_bounds.upper.w,
            output_bounds.lower.b, output_bounds.upper.b,
            concrete_bounds.lower, concrete_bounds.upper
        ])
        self.assertTrue(np.all(lw == 1.))
        self.assertTrue(np.all(lb == 2.))
        self.assertTrue(np.all(uw == 1.))
        self.assertTrue(np.all(ub == 2.))
        cl = cl.item()
        cu = cu.item()
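        # With unit weights and bias 2: lower = (0 + 1 + 2) + 2 = 5,
        # upper = (2 + 3 + 4) + 2 = 11.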
        self.assertAlmostEqual(5., cl)
        self.assertAlmostEqual(11., cu)

def testFCSymbolicBounds(self):
    m = snt.Linear(1,
                   initializers={
                       'w': tf.constant_initializer(1.),
                       'b': tf.constant_initializer(2.),
                   })
    z = tf.constant([[1, 2, 3]], dtype=tf.float32)
    m(z)  # Connect to create weights.
    m = ibp.LinearFCWrapper(m)
    input_bounds = ibp.IntervalBounds(z - 1., z + 1.)
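    # Convert the interval bounds into symbolic (linear) form, so bounds are
    # propagated as linear functions of the input rather than as plain
    # intervals.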
    input_bounds = ibp.SymbolicBounds.convert(input_bounds)
    output_bounds = m.propagate_bounds(input_bounds)
    concrete_bounds = ibp.IntervalBounds.convert(output_bounds)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        l, u, cl, cu = sess.run([
            output_bounds.lower, output_bounds.upper,
            concrete_bounds.lower, concrete_bounds.upper
        ])
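        # Each symbolic bound is a linear expression: for a linear layer, the
        # lower and upper expressions both reuse the layer's weights and bias,
        # and record the input interval in their .lower/.upper fields.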
        self.assertTrue(np.all(l.w == 1.))
        self.assertTrue(np.all(l.b == 2.))
        self.assertAlmostEqual([[0., 1., 2.]], l.lower.tolist())
        self.assertAlmostEqual([[2., 3., 4.]], l.upper.tolist())
        self.assertTrue(np.all(u.w == 1.))
        self.assertTrue(np.all(u.b == 2.))
        self.assertAlmostEqual([[0., 1., 2.]], u.lower.tolist())
        self.assertAlmostEqual([[2., 3., 4.]], u.upper.tolist())
        cl = cl.item()
        cu = cu.item()
        self.assertAlmostEqual(5., cl)
        self.assertAlmostEqual(11., cu)

def add_layer(net, module, inputs, flatten=False, batch_norm=None):
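    """Connects `module` to `inputs` and records the wrapped layer on `net`.

    Optionally flattens the inputs first and applies batch normalization
    afterwards. `net.append(wrapper, outputs, inputs)` is assumed to register
    a wrapped module together with its output and input tensors. Returns the
    outputs of the last module connected.
    """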
    if flatten:
        reshape_module = snt.BatchFlatten()
        outputs = reshape_module(inputs)
        net.append(
            ibp.BatchReshapeWrapper(reshape_module,
                                    outputs.shape[1:].as_list()),
            outputs, inputs)
        inputs = outputs

    outputs = module(inputs)
    if isinstance(module, AvgPool):
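        # `AvgPool` is assumed to be defined elsewhere in this file. Average
        # pooling is monotonically increasing in each input, so the generic
        # monotonic wrapper applies; it receives the pooling parameters in
        # NHWC layout.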
        module.__name__ = 'avg_pool'
        parameters = {
            'ksize': [1] + module.kernel_shape + [1],
            'padding': module.padding,
            'strides': [1] + module.strides + [1]
        }
        net.append(ibp.IncreasingMonotonicWrapper(module, **parameters),
                   outputs, inputs)
    elif isinstance(module, snt.Conv2D):
        net.append(ibp.LinearConv2dWrapper(module), outputs, inputs)
    elif isinstance(module, snt.Conv1D):
        net.append(ibp.LinearConv1dWrapper(module), outputs, inputs)
    elif isinstance(module, snt.Linear):
        net.append(ibp.LinearFCWrapper(module), outputs, inputs)
    else:
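        # Element-wise monotonic modules (e.g. ReLU) fall through to the
        # generic wrapper.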
        net.append(ibp.IncreasingMonotonicWrapper(module), outputs, inputs)

    if batch_norm is not None:
        inputs = outputs
        outputs = batch_norm(inputs, is_training=False, test_local_stats=False)
        net.append(ibp.BatchNormWrapper(batch_norm), outputs, inputs)

    return outputs

def testFCIntervalBounds(self):
    m = snt.Linear(1,
                   initializers={
                       'w': tf.constant_initializer(1.),
                       'b': tf.constant_initializer(2.),
                   })
    z = tf.constant([[1, 2, 3]], dtype=tf.float32)
    m(z)  # Connect to create weights.
    m = ibp.LinearFCWrapper(m)
    input_bounds = ibp.IntervalBounds(z - 1., z + 1.)
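    # IBP maps an interval with centre c and radius r through Wx + b to an
    # interval with centre cW + b and radius r|W|.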
    output_bounds = m.propagate_bounds(input_bounds)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        l, u = sess.run([output_bounds.lower, output_bounds.upper])
        l = l.item()
        u = u.item()
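        # lower = (0 + 1 + 2) + 2 = 5, upper = (2 + 3 + 4) + 2 = 11.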
        self.assertAlmostEqual(5., l)
        self.assertAlmostEqual(11., u)

def testCaching(self):
    m = snt.Linear(1,
                   initializers={
                       'w': tf.constant_initializer(1.),
                       'b': tf.constant_initializer(2.),
                   })
    z = tf.placeholder(shape=(1, 3), dtype=tf.float32)
    m(z)  # Connect to create weights.
    m = ibp.LinearFCWrapper(m)
    input_bounds = ibp.IntervalBounds(z - 1., z + 1.)
    output_bounds = m.propagate_bounds(input_bounds)

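    # Caching materialises the bound tensors, so later reads of lower/upper
    # return stored values; `update_cache_op` refreshes them from the live
    # tensors.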
    input_bounds.enable_caching()
    output_bounds.enable_caching()
    update_all_caches_op = tf.group(
        [input_bounds.update_cache_op, output_bounds.update_cache_op])

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())

        # Initialise the caches based on the model inputs.
        sess.run(update_all_caches_op, feed_dict={z: [[1., 2., 3.]]})

        l, u = sess.run([output_bounds.lower, output_bounds.upper])
        l = l.item()
        u = u.item()
        self.assertAlmostEqual(5., l)
        self.assertAlmostEqual(11., u)

        # Update the cache based on a different set of inputs.
        sess.run([output_bounds.update_cache_op],
                 feed_dict={z: [[2., 3., 7.]]})
        # We only updated the output bounds' cache. This asserts that the
        # computation depends on the underlying input bounds tensor, not on
        # the cached version of it. (Thus it does not matter in which order
        # the caches are updated.)

        l, u = sess.run([output_bounds.lower, output_bounds.upper])
        l = l.item()
        u = u.item()
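        # Recomputed from the newly cached input [[2, 3, 7]]:
        # lower = (1 + 2 + 6) + 2 = 11, upper = (3 + 4 + 8) + 2 = 17.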
        self.assertAlmostEqual(11., l)
        self.assertAlmostEqual(17., u)