def add_layer(net, module, inputs, flatten=False, batch_norm=None):
  """Applies `module` to `inputs` and records the wrapped layer(s) in `net`.

  Every layer built here is registered via `net.append(wrapper, outputs,
  inputs)`, pairing its IBP wrapper with the tensors it consumed and produced.

  Args:
    net: Collector exposing `append(wrapper, outputs, inputs)`.
    module: Layer to apply. `AvgPool`, `snt.Conv2D`, `snt.Conv1D` and
      `snt.Linear` get dedicated wrappers; any other module is wrapped with
      `ibp.IncreasingMonotonicWrapper`.
    inputs: Input tensor.
    flatten: If True, a `snt.BatchFlatten` layer is applied (and recorded)
      before `module`.
    batch_norm: Optional batch-norm module applied (and recorded) after
      `module`; it is called with `is_training=False, test_local_stats=False`.

  Returns:
    The output tensor of the last layer applied.
  """
  if flatten:
    flattener = snt.BatchFlatten()
    flat_inputs = flattener(inputs)
    flat_shape = flat_inputs.shape[1:].as_list()
    net.append(ibp.BatchReshapeWrapper(flattener, flat_shape),
               flat_inputs, inputs)
    inputs = flat_inputs

  outputs = module(inputs)

  if isinstance(module, AvgPool):
    # NOTE(review): presumably the monotonic wrapper uses `__name__` to
    # identify the op it wraps — confirm.
    module.__name__ = 'avg_pool'
    # Kernel and strides get a leading and trailing 1 (presumably the batch
    # and channel dimensions of an NHWC tensor — confirm).
    pool_kwargs = {
        'ksize': [1] + module.kernel_shape + [1],
        'padding': module.padding,
        'strides': [1] + module.strides + [1],
    }
    wrapper = ibp.IncreasingMonotonicWrapper(module, **pool_kwargs)
  else:
    # Checked in order; the first matching module type wins.
    linear_wrappers = (
        (snt.Conv2D, ibp.LinearConv2dWrapper),
        (snt.Conv1D, ibp.LinearConv1dWrapper),
        (snt.Linear, ibp.LinearFCWrapper),
    )
    for module_type, wrapper_type in linear_wrappers:
      if isinstance(module, module_type):
        wrapper = wrapper_type(module)
        break
    else:
      wrapper = ibp.IncreasingMonotonicWrapper(module)
  net.append(wrapper, outputs, inputs)

  if batch_norm is None:
    return outputs
  normalized = batch_norm(outputs, is_training=False, test_local_stats=False)
  net.append(ibp.BatchNormWrapper(batch_norm), normalized, outputs)
  return normalized
def testReluIntervalBounds(self):
  """Interval bounds through ReLU equal the element-wise ReLU of the inputs."""
  z = tf.constant([[-2, 3]], dtype=tf.float32)
  relu = ibp.IncreasingMonotonicWrapper(tf.nn.relu)
  # Input box is z +/- 1, i.e. lower [[-3, 2]] and upper [[-1, 4]].
  input_bounds = ibp.IntervalBounds(z - 1., z + 1.)
  output_bounds = relu.propagate_bounds(input_bounds)
  with self.test_session() as sess:
    lower, upper = sess.run([output_bounds.lower, output_bounds.upper])
    # assertAllClose compares arrays element-wise with a tolerance.
    # assertAlmostEqual on nested lists only passes on exact equality and
    # raises TypeError (rather than a clean failure) otherwise.
    self.assertAllClose([[0., 2.]], lower)
    self.assertAllClose([[0., 4.]], upper)
def testReluBackwardBounds(self):
  """Backward (CROWN-style) bounds through ReLU match the interval bounds."""
  z = tf.constant([[-2, 3]], dtype=tf.float32)
  relu = ibp.IncreasingMonotonicWrapper(tf.nn.relu)
  # Input box is z +/- 1, i.e. lower [[-3, 2]] and upper [[-1, 4]].
  input_bounds = ibp.IntervalBounds(z - 1., z + 1.)
  # Create IBP bounds. The return value is discarded, so this call is made
  # for its side effect on `relu` (presumably storing the interval bounds
  # that the backward propagation below relies on — confirm).
  relu.propagate_bounds(input_bounds)
  # NOTE(review): shape=(1, 2, 2) presumably means (batch, num_specs,
  # dimension) for an identity specification — confirm against
  # _generate_identity_spec.
  crown_init_bounds = _generate_identity_spec([relu], shape=(1, 2, 2),
                                              dimension=2)
  output_bounds = relu.propagate_bounds(crown_init_bounds)
  concrete_bounds = output_bounds.concretize()
  with self.test_session() as sess:
    lower, upper = sess.run([concrete_bounds.lower, concrete_bounds.upper])
    # Concretized bounds come from floating-point arithmetic, so compare
    # with a tolerance. assertAlmostEqual on nested lists would raise
    # TypeError on any inexact result instead of failing cleanly.
    self.assertAllClose([[0., 2.]], lower)
    self.assertAllClose([[0., 4.]], upper)