Beispiel #1
0
def RandomLayer(layer_a, layer_b, prob_a):
  """Randomly dispatches to one of two layers.

  Builds a `tl.Cond` whose condition draws a sample with `tl.RandomUniform()`
  and tests `sample < prob_a`. When the test holds, `layer_a` runs; otherwise
  `layer_b` runs.

  Args:
    layer_a: Layer executed when the random draw is below `prob_a`.
    layer_b: Layer executed otherwise.
    prob_a: Threshold compared against the draw. Assuming the draw is uniform
      on [0, 1) (TODO confirm RandomUniform defaults), this is the probability
      of running `layer_a`.

  Returns:
    A `tl.Cond` layer combining the random condition with both branches.
  """
  def below_threshold(sample):
    return sample < prob_a

  coin_flip = tl.Serial(
      tl.RandomUniform(),
      tl.Fn('SmallerThan', below_threshold),
  )
  return tl.Cond(coin_flip, layer_a, layer_b)
    def test_weights_and_state(self):
        """Checks how Cond exposes and accepts (cond, true, false) weights/state.

        `layer.weights` and `layer.state` are expected to equal the tuple of the
        three sublayers' values. Assigning to a branch, or to the Cond itself,
        must be reflected in that tuple.
        """
        cond = SmallerThan(3.0)
        true = tl.Dense(5)
        false = tl.Dense(5)
        # Not part of `layer`; only used as an alternative weight source below.
        different = tl.Dense(5)
        layer = tl.Cond(cond, true, false)
        # Presumably the scalar feeds `cond` and the vector feeds the selected
        # Dense branch — confirm against Cond's input-splitting rules.
        xs = (np.array(2.), np.array([0., 1., 2.]))
        layer.init(shapes.signature(xs))

        # weights
        self.assertEqual(as_list(layer.weights),
                         as_list((cond.weights, true.weights, false.weights)))
        # The two branches were initialized independently, so they should differ.
        self.assertNotEqual(as_list(true.weights), as_list(false.weights))
        # NOTE(review): `different` is never initialized in this test. If an
        # uninitialized Dense reports empty weights, this assertion and the
        # `layer.weights = ...` assignment below are vacuous. Consider
        # initializing `different` with the branch input signature — confirm.
        self.assertNotEqual(as_list(true.weights), as_list(different.weights))

        # Assigning to a branch directly must show up through the Cond.
        false.weights = true.weights
        self.assertEqual(as_list(layer.weights),
                         as_list((cond.weights, true.weights, true.weights)))

        # Assigning through the Cond must round-trip unchanged.
        layer.weights = (cond.weights, true.weights, different.weights)
        self.assertEqual(
            as_list(layer.weights),
            as_list((cond.weights, true.weights, different.weights)))
        # state
        self.assertEqual(as_list(layer.state),
                         as_list((cond.state, true.state, false.state)))
        # just check if simple assignments (setter from base.Layer) work correctly
        # with Cond.init_weights_and_state ; all states are empty so there is no
        # point in checking equality
        false.state = true.state
        layer.state = (cond.state, true.state, different.state)
 def test_condition_func_default_false(self):
     """Cond without a `false` branch passes inputs through when cond fails.

     The expected output is the second input unchanged, meaning the
     `DivideBy(2.)` branch was not applied.
     """
     predicate = SmallerThan(3.0)
     halve = DivideBy(2.)
     layer = tl.Cond(predicate, halve)
     inputs = (np.array(4.), np.array([4., 12.]))
     layer.init(shapes.signature(inputs))
     outputs = layer(inputs)
     self.assertEqual(as_list(outputs), [4., 12.])
Beispiel #4
0
 def test_exception_run1(self):
   """Checks that Cond over these branches accepts exactly one input.

   One input must init and run cleanly. Zero inputs, or two inputs, must raise.
   """
   predicate = SmallerThan(3.0)
   const_true = ReturnConst(2.)
   const_false = ReturnConst(5.)

   def init_and_run(layer, xs):
     layer.init(shapes.signature(xs))
     layer(xs)

   # A single input is accepted.
   one_input = np.array(4.)
   init_and_run(tl.Cond(predicate, const_true, const_false), one_input)

   # Zero inputs or two inputs must be rejected. The layer is built outside
   # the `with` so only init/run failures count as the expected exception.
   for bad_xs in ((), (np.array(4.), np.array([4., 12.]))):
     fresh_layer = tl.Cond(predicate, const_true, const_false)
     with self.assertRaises(Exception):
       init_and_run(fresh_layer, bad_xs)
 def test_complex_blocks(self):
     """A `ReturnConst(True)` condition routes an array to the `true` branch."""
     always_true = ReturnConst(True)
     halve = DivideBy(2.)
     quarter = DivideBy(4.)
     layer = tl.Cond(always_true, halve, quarter)
     inputs = [np.arange(5).astype(np.float32)]
     layer.init(shapes.signature(inputs))
     outputs = layer(inputs)
     # arange(5) halved: the DivideBy(2.) branch must have been taken.
     self.assertEqual(as_list(outputs), [0., 0.5, 1.0, 1.5, 2.0])
 def test_basic_false(self):
     """A `ReturnConst(False)` condition makes Cond yield the `false` branch."""
     layer = tl.Cond(ReturnConst(False), ReturnConst([2]), ReturnConst([5]))
     layer.init(())
     outputs = layer(())
     self.assertEqual(as_list(outputs), 5)
Beispiel #7
0
 def ConditionedBlock(current_layer_num):
   """Wraps one decoder block so it runs only for layers that are kept.

   Stack in: `embedding, n_layers_to_keep`; stack out: the same two items.
   The decoder block runs when `n_layers_to_keep > current_layer_num`;
   otherwise an empty `tl.Serial()` leaves the stack as is.
   """
   # n_layers_to_keep, embedding, n_layers_to_keep
   reorder = tl.Select([1, 0, 1])
   # Condition consumes the leading n_layers_to_keep.
   keep_layer = LargerThan(float(current_layer_num))
   decoder_branch = tl.Serial(transformer._DecoderBlock(  # pylint: disable=protected-access
       d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
   skip_branch = tl.Serial()
   return tl.Serial(
       reorder,
       tl.Cond(keep_layer, decoder_branch, skip_branch),
   )
Beispiel #8
0
 def ConditionedBlock(current_layer_num):
   """Wraps one decoder block behind a threshold test (no explicit else).

   Stack in: `embedding, n_layers_to_keep`; stack out: the same two items.
   The block runs when the leading stack value exceeds `skip_fraction` for
   layers selected by `skip_mode_fun`, or exceeds 0.0 for all other layers.
   The false branch is left to Cond's default.

   NOTE(review): the second stack item is labelled `n_layers_to_keep` but is
   compared against `skip_fraction`. It is presumably a random draw supplied
   upstream — confirm the label.
   """
   # n_layers_to_keep, embedding, n_layers_to_keep
   reorder = tl.Select([1, 0, 1])
   if skip_mode_fun(current_layer_num):
     threshold = skip_fraction
   else:
     threshold = 0.0
   gate = LargerThan(threshold)
   decoder_branch = tl.Serial(transformer._DecoderBlock(  # pylint: disable=protected-access
       d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
   return tl.Serial(reorder, tl.Cond(gate, decoder_branch))
Beispiel #9
0
 def ConditionedBlock(current_layer_num):
   """Wraps one decoder block in stochastic layer skipping.

   Stack in: `embedding`; stack out: `embedding`.

   A sample is pushed with `tl.RandomUniform(0., 1, sync=True)`. The decoder
   block runs only when that sample is larger than a threshold; otherwise an
   empty `tl.Serial()` leaves the embedding as is.
   - In 'train' mode the threshold is `skip_fraction[current_layer_num]`.
   - In every other mode the threshold is negative, so the block always runs.
   """
   if mode == 'train':
     threshold = skip_fraction[current_layer_num]
   else:
     # The previous 0.0 threshold was not safe with a strict `>` test.
     # Assuming RandomUniform samples [0, 1) with 0 inclusive (as
     # jax.random.uniform does), a draw of exactly 0.0 would skip the block
     # outside training. A negative threshold makes the block unconditional.
     threshold = -1.0
   return tl.Serial(
       # stack: embedding
       tl.RandomUniform(0., 1, sync=True),
       # stack: random_uniform, embedding
       tl.Cond(
           # if random_uniform > threshold
           LargerThan(threshold),
           # then: run block
           tl.Serial(transformer._DecoderBlock(  # pylint: disable=protected-access
               d_model, d_ff, n_heads, dropout, [], mode, ff_activation)),
           # else: run noop
           tl.Serial()
           )
       # stack: embedding
       )
 def test_exception_n_in(self):
     """Cond must reject branches whose n_in does not match (ValueError)."""
     predicate = SmallerThan(3.0)
     const_branch = ReturnConst(2.)
     halve_branch = DivideBy(2.)
     with self.assertRaises(ValueError):
         tl.Cond(predicate, const_branch, halve_branch)
 def test_exception_n_out(self):
     """Cond must reject branches whose n_out does not match (ValueError)."""
     predicate = SmallerThan(3.0)
     halve_branch = DivideBy(2.)
     dup_branch = tl.Dup()
     with self.assertRaises(ValueError):
         tl.Cond(predicate, halve_branch, dup_branch)