def testBatchNorm(self):
  """Checks batch_norm passes real inputs through and vanishes on VANISHED."""
  # NOTE(review): a second testBatchNorm appears later in this class; with
  # duplicate method names only the last definition is collected — confirm
  # whether this one should be merged or renamed.
  decorator = ops.ConfigurableOps()
  kwargs = dict(center=False, scale=False)
  decorator_regular_output = decorator.batch_norm(self.inputs, **kwargs)
  decorator_zero_output = decorator.batch_norm(ops.VANISHED, **kwargs)
  # Use the file's `layers` alias (as the other tests do) rather than
  # reaching through tf.contrib directly.
  tf_output = layers.batch_norm(self.inputs, **kwargs)
  # batch_norm creates variables; they must be initialized inside a session
  # before the tensor values can be compared (the original asserted on
  # uninitialized graph tensors).
  with self.cached_session():
    tf.global_variables_initializer().run()
    self.assertAllEqual(decorator_regular_output, tf_output)
    self.assertTrue(ops.is_vanished(decorator_zero_output))
def testTowerVanishes(self, parameterization):
  """A conv tower fully zeroed out by its parameterization becomes VANISHED."""
  input_depth = self.inputs.shape.as_list()[3]
  configurable = ops.ConfigurableOps(parameterization=parameterization)
  tower = configurable.conv2d(
      self.inputs, num_outputs=12, kernel_size=3, scope='first')
  tower = configurable.conv2d(
      tower, num_outputs=input_depth, kernel_size=1, scope='second')
  self.assertTrue(ops.is_vanished(tower))
def testDefaultToZero(self):
  """With fallback_rule='zero', ops absent from the parameterization vanish."""
  decorator = ops.ConfigurableOps(
      parameterization={'first/Conv2D': 3}, fallback_rule='zero')
  parameterized_out = decorator.conv2d(
      self.inputs, num_outputs=12, kernel_size=3, scope='first')
  fallback_out = decorator.conv2d(
      self.inputs, 13, kernel_size=1, scope='second')
  # The parameterized op keeps its configured channel count.
  self.assertEqual(3, parameterized_out.shape.as_list()[3])
  # The unlisted op falls back to zero outputs and is recorded as such.
  self.assertTrue(ops.is_vanished(fallback_out))
  self.assertEqual(0, decorator.constructed_ops['second/Conv2D'])
def testBatchNorm(self):
  """batch_norm agrees with the layers version and vanishes on VANISHED."""
  configurable = ops.ConfigurableOps()
  bn_kwargs = dict(center=False, scale=False)
  regular_out = configurable.batch_norm(self.inputs, **bn_kwargs)
  vanished_out = configurable.batch_norm(ops.VANISHED, **bn_kwargs)
  expected = layers.batch_norm(self.inputs, **bn_kwargs)
  # Variables created by batch_norm need initializing before comparison.
  with self.cached_session():
    tf.global_variables_initializer().run()
    self.assertAllEqual(regular_out, expected)
    self.assertTrue(ops.is_vanished(vanished_out))
def testPool(self):
  """Pooling ops match the layers versions and vanish on VANISHED input."""
  configurable = ops.ConfigurableOps()
  vanished_input = ops.VANISHED
  shared_kwargs = dict(kernel_size=2, stride=2, padding='same', scope='pool')
  for pool_name in ['max_pool2d', 'avg_pool2d']:
    configurable_pool = getattr(configurable, pool_name)
    regular_out = configurable_pool(self.inputs, **shared_kwargs)
    vanished_out = configurable_pool(vanished_input, **shared_kwargs)
    expected = getattr(layers, pool_name)(self.inputs, **shared_kwargs)
    self.assertAllEqual(regular_out, expected)
    self.assertTrue(ops.is_vanished(vanished_out))
def testConcatHijack(self):
  """hijack_module_functions makes a module's concat tolerate VANISHED."""
  configurable = ops.ConfigurableOps()
  fake_module = Fake()
  real_input = tf.ones([2, 3, 3, 5])
  empty = ops.VANISHED
  # Before the hijack, a VANISHED element makes concat raise.
  with self.assertRaises(ValueError):
    # empty will generate an error before the hijack.
    _ = fake_module.concat([real_input, empty], 3).shape.as_list()
  # hijacking:
  ops.hijack_module_functions(configurable, fake_module)
  # Verifying success of hijack: VANISHED elements are now skipped.
  self.assertAllEqual(
      fake_module.concat([real_input, empty], 3).shape.as_list(),
      [2, 3, 3, 5])
  self.assertTrue(ops.is_vanished(fake_module.concat([empty, empty], 3)))
  self.assertAllEqual(
      fake_module.concat([real_input, empty, real_input], 3).shape.as_list(),
      [2, 3, 3, 10])