def RandomLayer(layer_a, layer_b, prob_a):
  """Returns a layer that runs `layer_a` with probability `prob_a`, else `layer_b`.

  A sample drawn by `tl.RandomUniform()` is compared against `prob_a`; the
  resulting boolean drives `tl.Cond`, which runs `layer_a` when the sample is
  smaller than `prob_a` and `layer_b` otherwise.
  NOTE(review): "probability `prob_a`" assumes RandomUniform's default range
  is [0, 1) — confirm against the Trax version in use.

  Args:
    layer_a: Layer run when the random sample is below `prob_a`.
    layer_b: Layer run otherwise.
    prob_a: Threshold the random sample is compared against.

  Returns:
    A `tl.Cond` layer choosing between `layer_a` and `layer_b`.
  """
  def _below_threshold(sample):
    return sample < prob_a

  draw_and_compare = tl.Serial(
      tl.RandomUniform(),
      tl.Fn('SmallerThan', _below_threshold),
  )
  return tl.Cond(draw_and_compare, layer_a, layer_b)
def test_weights_and_state(self):
  """Checks Cond's weights/state getters and setters against its sub-layers.

  Cond's `weights` and `state` should read back as the 3-tuple of
  (cond, true, false) sub-layer values. Assignments, both on a sub-layer and
  on the Cond layer itself, should be reflected in later reads.
  """
  cond = SmallerThan(3.0)
  true = tl.Dense(5)
  false = tl.Dense(5)
  # `different` is never initialized in this test; it only serves as a source
  # of weights/state that differ from `true`'s for the assignments below.
  different = tl.Dense(5)
  layer = tl.Cond(cond, true, false)
  xs = (np.array(2.), np.array([0., 1., 2.]))
  layer.init(shapes.signature(xs))

  # weights
  # Cond's weights read back as (cond, true, false) sub-layer weights.
  self.assertEqual(as_list(layer.weights),
                   as_list((cond.weights, true.weights, false.weights)))
  self.assertNotEqual(as_list(true.weights), as_list(false.weights))
  self.assertNotEqual(as_list(true.weights), as_list(different.weights))

  # Assigning to a sub-layer must be visible through the Cond layer.
  false.weights = true.weights
  self.assertEqual(as_list(layer.weights),
                   as_list((cond.weights, true.weights, true.weights)))

  # Assigning to the Cond layer must round-trip through its getter.
  layer.weights = (cond.weights, true.weights, different.weights)
  self.assertEqual(
      as_list(layer.weights),
      as_list((cond.weights, true.weights, different.weights)))

  # state
  self.assertEqual(as_list(layer.state),
                   as_list((cond.state, true.state, false.state)))
  # just check if simple assignments (setter from base.Layer) work correctly
  # with Cond.init_weights_and_state ; all states are empty so there is no
  # point in checking equality
  false.state = true.state
  layer.state = (cond.state, true.state, different.state)
def test_condition_func_default_false(self):
  """With no false branch, a failing condition leaves the data unchanged."""
  condition = SmallerThan(3.0)
  branch_true = DivideBy(2.)
  layer = tl.Cond(condition, branch_true)
  inputs = (np.array(4.), np.array([4., 12.]))
  layer.init(shapes.signature(inputs))
  outputs = layer(inputs)
  # The condition input is 4., which is not < 3., so the true branch
  # (division) must not be applied.
  self.assertEqual(as_list(outputs), [4., 12.])
def test_exception_run1(self):
  """Running this Cond succeeds with one input and fails with zero or two.

  Each failing input count runs in its own subTest so both cases are reported
  independently.
  """
  # We expect exactly one input.
  cond = SmallerThan(3.0)
  true = ReturnConst(2.)
  false = ReturnConst(5.)

  def init_and_run(layer, xs):
    layer.init(shapes.signature(xs))
    layer(xs)

  # It will pass with one input.
  xs = np.array(4.)
  layer = tl.Cond(cond, true, false)
  init_and_run(layer, xs)

  # It will fail with zero or two inputs.
  for xs in ((), (np.array(4.), np.array([4., 12.]))):
    with self.subTest(n_inputs=len(xs)):
      layer = tl.Cond(cond, true, false)
      # Context-manager form evaluates the call in-place, so there is no
      # closure over loop variables (and no pylint suppression needed).
      with self.assertRaises(Exception):
        init_and_run(layer, xs)
def test_complex_blocks(self):
  """A constant-True condition routes the input through the true branch."""
  always_true = ReturnConst(True)
  halve = DivideBy(2.)
  quarter = DivideBy(4.)
  layer = tl.Cond(always_true, halve, quarter)
  inputs = [np.arange(5).astype(np.float32)]
  layer.init(shapes.signature(inputs))
  expected = [0., 0.5, 1.0, 1.5, 2.0]  # arange(5) / 2
  self.assertEqual(as_list(layer(inputs)), expected)
def test_basic_false(self):
  """A constant-False condition selects the false branch's constant."""
  condition = ReturnConst(False)
  branch_true = ReturnConst([2])
  branch_false = ReturnConst([5])
  layer = tl.Cond(condition, branch_true, branch_false)
  layer.init(())
  no_inputs = ()
  self.assertEqual(as_list(layer(no_inputs)), 5)
def ConditionedBlock(current_layer_num):
  """Wraps one decoder block so it runs only while enough layers are kept.

  Stack on entry: embedding, n_layers_to_keep. The decoder block is applied
  to `embedding` when `n_layers_to_keep > current_layer_num`; otherwise the
  no-op branch runs. Stack on exit: embedding, n_layers_to_keep.
  """
  # -> n_layers_to_keep, embedding, n_layers_to_keep; the leading copy is the
  # condition's input.
  reorder = tl.Select([1, 0, 1])
  keep_this_layer = LargerThan(float(current_layer_num))
  decoder_block = tl.Serial(
      transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
          d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
  skip = tl.Serial()  # empty Serial: the no-op else branch
  return tl.Serial(reorder, tl.Cond(keep_this_layer, decoder_block, skip))
def ConditionedBlock(current_layer_num):
  """Wraps one decoder block behind a threshold test on the second stack value.

  The threshold is `skip_fraction` when `skip_mode_fun(current_layer_num)` is
  truthy and 0.0 otherwise; the block runs when the compared value is larger.
  There is no explicit else branch (noop is implicit).
  """
  # Stack on entry: embedding, n_layers_to_keep.
  # NOTE(review): the condition is described as "random() > skip_fraction",
  # so this value may be a random draw rather than a layer count — confirm.
  reorder = tl.Select([1, 0, 1])  # n_layers_to_keep, embedding, n_layers_to_keep
  # if random() > skip_fraction OR layer not in skip_mode ...
  threshold = skip_fraction if skip_mode_fun(current_layer_num) else 0.0
  run_block = LargerThan(threshold)
  decoder_block = tl.Serial(
      transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
          d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
  # else: noop (implicit). Stack on exit: embedding, n_layers_to_keep.
  return tl.Serial(reorder, tl.Cond(run_block, decoder_block))
def ConditionedBlock(current_layer_num):
  """Wraps one decoder block so it runs only when a uniform draw is large enough.

  A synced uniform sample is pushed onto the stack. The block runs on the
  embedding when the sample exceeds `skip_fraction[current_layer_num]`; the
  threshold is 0.0 when `mode != 'train'`. Otherwise a no-op runs.
  Stack on entry and exit: embedding.
  """
  draw = tl.RandomUniform(0., 1, sync=True)  # stack: random_uniform, embedding
  threshold = skip_fraction[current_layer_num] if mode == 'train' else 0.0
  run_block = LargerThan(threshold)
  decoder_block = tl.Serial(
      transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
          d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
  noop = tl.Serial()
  return tl.Serial(draw, tl.Cond(run_block, decoder_block, noop))
def test_exception_n_in(self):
  """Constructing Cond with this branch pairing raises ValueError.

  Per the test name, the branches disagree on their number of inputs.
  """
  condition = SmallerThan(3.0)
  branch_true = ReturnConst(2.)
  branch_false = DivideBy(2.)
  with self.assertRaises(ValueError):
    tl.Cond(condition, branch_true, branch_false)
def test_exception_n_out(self):
  """Constructing Cond with this branch pairing raises ValueError.

  Per the test name, the branches disagree on their number of outputs.
  """
  condition = SmallerThan(3.0)
  branch_true = DivideBy(2.)
  branch_false = tl.Dup()
  with self.assertRaises(ValueError):
    tl.Cond(condition, branch_true, branch_false)