def testBatchNorm(self):
    """Checks the decorated batch_norm on regular and vanished inputs.

    On `self.inputs` the decorator's batch_norm output is expected to equal
    the output of `tf.contrib.layers.batch_norm` called with the same
    keyword arguments. On `ops.VANISHED` the result is expected to satisfy
    `ops.is_vanished`.
    """
    configurable = ops.ConfigurableOps()
    bn_kwargs = {'center': False, 'scale': False}

    # Build the ops in a fixed order: decorated (regular), decorated
    # (vanished input), then the plain TensorFlow reference.
    regular_output = configurable.batch_norm(self.inputs, **bn_kwargs)
    vanished_output = configurable.batch_norm(ops.VANISHED, **bn_kwargs)
    reference_output = tf.contrib.layers.batch_norm(self.inputs, **bn_kwargs)

    self.assertAllEqual(regular_output, reference_output)
    self.assertTrue(ops.is_vanished(vanished_output))
Esempio n. 2
0
  def testTowerVanishes(self, parameterization):
    """Expects a two-convolution tower to be reported as vanished.

    Args:
      parameterization: Forwarded unchanged to `ops.ConfigurableOps`.
        NOTE(review): presumably supplied by a parameterized-test
        decorator that is not visible in this snippet; confirm.
    """
    # Size of dimension 3 of the test inputs; requested as the output count
    # of the second convolution.
    input_depth = self.inputs.shape.as_list()[3]
    configurable = ops.ConfigurableOps(parameterization=parameterization)

    first_conv = configurable.conv2d(
        self.inputs, num_outputs=12, kernel_size=3, scope='first')
    tower = configurable.conv2d(
        first_conv, num_outputs=input_depth, kernel_size=1, scope='second')

    self.assertTrue(ops.is_vanished(tower))
Esempio n. 3
0
 def testDefaultToZero(self):
   """Checks conv2d results when `fallback_rule='zero'` is configured.

   Only 'first/Conv2D' appears in the parameterization (mapped to 3). The
   test expects the 'first' convolution to have size 3 along dimension 3,
   the 'second' convolution to be vanished, and
   `constructed_ops['second/Conv2D']` to be recorded as 0.
   """
   configurable = ops.ConfigurableOps(
       parameterization={'first/Conv2D': 3}, fallback_rule='zero')

   listed_conv = configurable.conv2d(
       self.inputs, num_outputs=12, kernel_size=3, scope='first')
   unlisted_conv = configurable.conv2d(
       self.inputs, 13, kernel_size=1, scope='second')

   # The parameterized value (3) is expected instead of the requested 12.
   self.assertEqual(3, listed_conv.shape.as_list()[3])
   self.assertTrue(ops.is_vanished(unlisted_conv))
   self.assertEqual(0, configurable.constructed_ops['second/Conv2D'])
Esempio n. 4
0
    def testBatchNorm(self):
        """Checks the decorated batch_norm on regular and vanished inputs.

        On `self.inputs` the decorator's batch_norm output is expected to
        equal the output of `layers.batch_norm` called with the same keyword
        arguments; that comparison runs inside a cached session after global
        variables have been initialized. On `ops.VANISHED` the result is
        expected to satisfy `ops.is_vanished`.
        """
        configurable = ops.ConfigurableOps()
        bn_kwargs = {'center': False, 'scale': False}

        # Build the ops in a fixed order: decorated (regular), decorated
        # (vanished input), then the plain reference layer.
        regular_output = configurable.batch_norm(self.inputs, **bn_kwargs)
        vanished_output = configurable.batch_norm(ops.VANISHED, **bn_kwargs)
        reference_output = layers.batch_norm(self.inputs, **bn_kwargs)

        with self.cached_session():
            init_op = tf.global_variables_initializer()
            init_op.run()
            self.assertAllEqual(regular_output, reference_output)

        self.assertTrue(ops.is_vanished(vanished_output))
Esempio n. 5
0
  def testPool(self):
    """Checks the decorated max/avg pooling on regular and vanished inputs.

    For each of `max_pool2d` and `avg_pool2d`, the decorator's output on
    `self.inputs` is expected to equal the same-named function from
    `layers`, and its output on `ops.VANISHED` is expected to satisfy
    `ops.is_vanished`.
    """
    configurable = ops.ConfigurableOps()
    vanished_input = ops.VANISHED
    shared_kwargs = {
        'kernel_size': 2,
        'stride': 2,
        'padding': 'same',
        'scope': 'pool',
    }

    for pool_name in ('max_pool2d', 'avg_pool2d'):
      decorated_pool = getattr(configurable, pool_name)
      regular_output = decorated_pool(self.inputs, **shared_kwargs)
      vanished_output = decorated_pool(vanished_input, **shared_kwargs)

      # Reference result from the undecorated layer of the same name.
      reference_pool = getattr(layers, pool_name)
      reference_output = reference_pool(self.inputs, **shared_kwargs)

      self.assertAllEqual(regular_output, reference_output)
      self.assertTrue(ops.is_vanished(vanished_output))
Esempio n. 6
0
    def testConcatHijack(self):
        """Checks `concat` on a module before and after it is hijacked.

        Before `ops.hijack_module_functions` is applied, concatenating a
        tensor with `ops.VANISHED` is expected to raise ValueError. After
        it, vanished entries are expected to contribute nothing along the
        concat axis, and a concat of only vanished entries is expected to
        be vanished.
        """
        configurable = ops.ConfigurableOps()
        fake_module = Fake()
        ones = tf.ones([2, 3, 3, 5])
        vanished = ops.VANISHED

        # Un-hijacked concat is expected to reject the vanished entry.
        with self.assertRaises(ValueError):
            fake_module.concat([ones, vanished], 3).shape.as_list()

        ops.hijack_module_functions(configurable, fake_module)

        # One real tensor plus a vanished one: shape of the real tensor.
        single = fake_module.concat([ones, vanished], 3)
        self.assertAllEqual(single.shape.as_list(), [2, 3, 3, 5])

        # Only vanished entries: the result is vanished too.
        all_vanished = fake_module.concat([vanished, vanished], 3)
        self.assertTrue(ops.is_vanished(all_vanished))

        # Two real tensors around a vanished one: axis 3 goes 5 + 5 = 10.
        double = fake_module.concat([ones, vanished, ones], 3)
        self.assertAllEqual(double.shape.as_list(), [2, 3, 3, 10])