def FProp(self, theta, inputs, *extra_inputs):
  """Forward pass of one reversible block: y1 = x1 + f(x2); y2 = x2 + g(y1).

  The step seed is captured immediately before each sub-block's FProp so that
  any seed-consuming ops inside f/g (e.g. deterministic dropout) can be
  replayed exactly during the reverse/gradient computation, which receives
  f_seed and g_seed back (see ReverseAndGrad usage in the stacked layer).

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    inputs: A NestedMap: .split1 and .split2 corresponding to x1 and x2.
    *extra_inputs: additional inputs that will be passed to both f and g. No
      gradient will be computed for these inputs.

  Returns:
    outputs: A NestedMap: .split1 and .split2 corresponding to y1 and y2.
    f_seed: Scalar tensor. The step seed used in forward for the f block.
    g_seed: Scalar tensor. The step seed used in forward for the g block.
  """
  # Capture the seed BEFORE f consumes it, so the backward pass can reset to
  # this value and reproduce f's random ops bit-exactly.
  f_seed = py_utils.GetStepSeed()
  f_out = self.f_block.FProp(theta.f_block, inputs.split2, *extra_inputs)
  z1 = inputs.split1 + f_out
  # Likewise capture the seed state before g runs.
  g_seed = py_utils.GetStepSeed()
  g_out = self.g_block.FProp(theta.g_block, z1, *extra_inputs)
  y2 = inputs.split2 + g_out
  # This is essential to make dy1 independent to y2.
  y1 = tf.identity(z1)
  return py_utils.NestedMap(split1=y1, split2=y2), f_seed, g_seed
def FProp(self, theta, inputs, *extra_inputs):
  """Forward pass through the stack of reversible sub-layers.

  When params.custom_gradient is set, the forward runs inside a Defun whose
  backward (Bak) reconstructs each layer's input from its output via
  ReverseAndGrad instead of storing activations, trading compute for memory.
  Step seeds are snapshotted per layer so seed-consuming ops (e.g. dropout)
  replay identically in the backward pass.

  Args:
    theta: A NestedMap of this layer's (and children's) weight values; must
      contain theta.sub_layers aligned with self.sub_layers.
    inputs: The input activations fed to the first sub-layer (a NestedMap of
      splits, as consumed by each sub-layer's FProp).
    *extra_inputs: Additional inputs passed to every sub-layer. No gradient
      is computed for them (Bak returns zeros).

  Returns:
    The final activations after all sub-layers.
  """
  initial_step_seed = py_utils.GetStepSeed()
  # Derive a fresh, graph-unique seed from a throwaway op's name; both the
  # custom-gradient forward and backward reset to this same value so graph
  # construction continues from a consistent seed state on either path.
  final_step_seed = py_utils.GenerateSeedFromName(
      tf.no_op(name='new_step_seed').name)
  num_layers = len(self.sub_layers)

  def Bak(inputs, outputs, d_outputs):
    """Backward step: reconstruct inputs and accumulate grads layer-by-layer."""
    del inputs  # unused
    output_acts, step_seeds = outputs
    d_outputs = d_outputs[0]

    d_layer_thetas = []
    # Walk the stack in reverse; each ReverseAndGrad call both inverts the
    # layer (recovering its input activations) and produces gradients, using
    # the recorded f/g seeds to replay the forward's random ops.
    for layer_idx in reversed(range(num_layers)):
      f_seed, g_seed = step_seeds[layer_idx]
      layer = self.sub_layers[layer_idx]
      layer_theta = theta.sub_layers[layer_idx]
      input_acts, d_inputs, d_theta = layer.ReverseAndGrad(
          layer_theta, output_acts, d_outputs, f_seed, g_seed, *extra_inputs)

      d_layer_thetas.append(d_theta)
      # Passes reconstructed inputs to the previous layer.
      output_acts = input_acts
      d_outputs = d_inputs
    # Leave the global step seed in the same state the forward path left it.
    py_utils.ResetStepSeed(final_step_seed)

    d_theta = py_utils.NestedMap()
    # Grads were collected last-to-first; restore sub-layer order.
    d_theta.sub_layers = list(reversed(d_layer_thetas))
    # extra_inputs get no gradient by contract.
    extra_grads = [tf.zeros_like(t) for t in extra_inputs]
    # Must mirror the CallDefun input list:
    # [initial_step_seed, theta, inputs, extra_inputs].
    return [
        tf.zeros_like(initial_step_seed), d_theta, d_inputs, extra_grads
    ]

  def Fwd(xs):
    """Forward pass: run all sub-layers, recording per-layer step seeds."""
    initial_step_seed, theta, acts, extra_inputs = xs
    # Inside the Defun, restart seeding from the captured outer seed so the
    # forward is reproducible by Bak.
    py_utils.ResetStepSeed(initial_step_seed)
    layer_step_seeds = []

    for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
      acts, f_seed, g_seed = layer.FProp(layer_theta, acts, *extra_inputs)
      layer_step_seeds += [(f_seed, g_seed)]
    # The seeds ride along as a second output so Bak can replay each layer.
    return [acts, layer_step_seeds]

  if self.params.custom_gradient:
    acts, _ = py_utils.CallDefun(
        Fwd, [initial_step_seed, theta, inputs, extra_inputs], Bak)
    py_utils.ResetStepSeed(final_step_seed)
    return acts
  else:
    # Plain (memory-unoptimized) path: just chain the sub-layers, discarding
    # the per-layer seeds.
    acts = inputs
    for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
      acts, _, _ = layer.FProp(layer_theta, acts, *extra_inputs)
    return acts
def testRematerialize(self):
  """Checks dropout consistency between fprop and the rematerialized bprop.

  With w == 1 and loss = sum(y), dy/dw equals the final activation y, so the
  gradient must reproduce exactly the same dropout masks as the forward pass.
  The final step-seed assertion pins the graph's op-creation order.
  """
  # Test the dropout consistency between fprop and bprop.
  b = builder.Base.Params()
  b = b.Instantiate()
  start_block = layers.DeterministicDropoutLayer.Params().Set(
      name='start_dropout', keep_prob=0.7)
  # Build 4 dropout layers; group them 2 per cell, each cell wrapped by
  # _Rematerialize so its activations are recomputed during bprop.
  num_blocks = 4
  blocks = []
  blocks_per_cell = 2
  for i in range(num_blocks):
    blocks.append(layers.DeterministicDropoutLayer.Params().Set(
        name='dropout_{}'.format(i), keep_prob=0.7))
  cells = []
  while blocks:
    # Peel off the next blocks_per_cell layers into one rematerialized cell.
    heads, blocks = blocks[:blocks_per_cell], blocks[blocks_per_cell:]
    cell_name = 'cell_{}'.format(len(cells))
    cells.append(
        b._Rematerialize(name=cell_name, body=b._Seq(cell_name, *heads)))
  with self.session(use_gpu=False, graph=tf.Graph()) as sess:
    tf.random.set_seed(12345)
    p = b._Seq('test', start_block, *cells)
    mdl = p.Instantiate()
    # y = mdl.FProp(x * w)
    # Fake input
    x = tf.ones([4, 5])
    # Construct weights initialized to all ones so x * w == x.
    w = tf.get_variable(
        'w',
        shape=[4, 5],
        initializer=tf.constant_initializer([[1] * 5] * 4))
    y = mdl.FPropDefaultTheta(x * w)
    # Construct loss function such that gradients = final activation.
    # dy/dw = y = mdl.FProp(x * w) when w is 1.
    loss = tf.reduce_sum(y)
    grads = py_utils.ComputeGradients(loss, py_utils.NestedMap(w=w))
    tf.global_variables_initializer().run()
    y_val, grads_val = sess.run([y, grads.Transform(tuple)])
    # grads.Transform(tuple) yields (var, grad) pairs; take the grad half.
    grads_val = grads_val['w'][1]
    # Gradient must equal the forward activation, i.e. bprop replayed the
    # identical dropout masks.
    self.assertAllClose(y_val, grads_val)
    # Golden value pins the deterministic step-seed sequence; changes in op
    # creation order would alter it.
    self.assertEqual(py_utils.GetStepSeed().eval(), 1553244033)