def testSymbolToTensorMap(self):
  """Checks that cell_fn / cell_grad may read the active symbol-to-tensor map.

  recurrent.Recurrent is built twice: once relying on the automatically
  derived gradient, and once with an explicit cell_grad. In both cases the
  value of symbol `x` reaches the cell only through
  symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, ...).
  """
  x = symbolic.Symbol('x')
  y = symbolic.Symbol('y')

  def PlusWXT(theta, state, inputs):
    """state.value += theta.w * x * inputs.t."""
    next_state = py_utils.NestedMap()
    x_value = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
    next_state.value = state.value + theta.w * x_value * inputs.t
    return next_state, py_utils.NestedMap()

  def PlusWXTGrad(theta, state0, inputs, extras, dstate1):
    """Hand-written gradient of PlusWXT."""
    del state0, extras  # Not used by this gradient.
    x_value = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
    upstream = dstate1.value
    dtheta = py_utils.NestedMap(w=upstream * x_value * inputs.t)
    dstate0 = py_utils.NestedMap(value=upstream)
    dinputs = py_utils.NestedMap(t=upstream * theta.w * x_value)
    return dtheta, dstate0, dinputs, None

  with self.session() as sess:
    theta = py_utils.NestedMap(w=tf.constant(1., name='w'))
    state0 = py_utils.NestedMap(value=tf.constant(0., name='value'))
    inputs = py_utils.NestedMap(t=tf.constant([1., 2., 3.], name='t'))
    t_sum = 1. + 2. + 3.

    def BuildRunAndCheck(symbol_to_tensor, expected_x, **recurrent_kwargs):
      """Builds Recurrent under `symbol_to_tensor`, runs it, checks results."""
      with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, symbol_to_tensor):
        x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
        _, state1 = recurrent.Recurrent(theta, state0, inputs, PlusWXT,
                                        **recurrent_kwargs)
        # Gradients are built while the map is still active; presumably the
        # backward path (e.g. PlusWXTGrad) calls EvalExpr on x as well.
        dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0]
        dx = tf.gradients(ys=[state1.value], xs=[x_tensor])[0]
        final_value, x_val, dx_val, dw_val = sess.run(
            [state1.value, x_tensor, dx, dw])
      self.assertEqual(x_val, expected_x)
      self.assertEqual(final_value, x_val * t_sum)
      self.assertEqual(dw_val, x_val * t_sum)
      self.assertEqual(dx_val, t_sum)

    # Automatic cell_grad. `y` is mapped too, but the cell never reads it.
    BuildRunAndCheck({x: tf.constant(7., name='x7'), y: 8}, 7)
    # Manual cell_grad.
    BuildRunAndCheck({x: tf.constant(5., name='x5')}, 5,
                     cell_grad=PlusWXTGrad)
def testEvalExpr(self):
  """Tests EvalExpr against nested, value-typed symbol-to-value maps.

  Uses the typed API, SymbolToValueMap(value_type, map) and
  EvalExpr(value_type, expr), where value_type is symbolic.STATIC_VALUES or
  symbolic.TENSOR_VALUES.
  """
  x = symbolic.NewSymbol('x')
  y = symbolic.NewSymbol('y')
  xy = x * y

  a = tf.placeholder(tf.float32)
  b = tf.placeholder(tf.float32)
  with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {x: 2, y: 3}):
    self.assertEqual(symbolic.EvalExpr(symbolic.STATIC_VALUES, xy), 6)
    # The inner map overrides the outer map of the same value type.
    with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {x: 5, y: 6}):
      self.assertEqual(symbolic.EvalExpr(symbolic.STATIC_VALUES, xy), 30)
    # A tensor-valued map does not affect static evaluation.
    with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, {x: a, y: b}):
      ab = symbolic.EvalExpr(symbolic.TENSOR_VALUES, xy)
      self.assertEqual(symbolic.EvalExpr(symbolic.STATIC_VALUES, xy), 6)
    # Leaving the inner maps restores the outer one.
    self.assertEqual(symbolic.EvalExpr(symbolic.STATIC_VALUES, xy), 6)

  # EvalExpr can also evaluate a symbolic expression to a Tensor.
  self.assertIsInstance(ab, tf.Tensor)
  with self.session() as sess:
    self.assertEqual(12, sess.run(ab, {a: 3, b: 4}))

  with self.assertRaises(Exception):
    # EvalExpr does not support partial evaluation.
    with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {y: 3}):
      symbolic.EvalExpr(symbolic.STATIC_VALUES, xy)
def FProp(self, theta, inputs):
  """Apply projection to inputs.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: The inputs tensor.  Shaped [..., input_dims].

  Returns:
    Projected inputs.
  """
  p = self.params
  with tf.name_scope(p.name):
    # Cost estimate: (number of leading positions) * (number of weights) * 2;
    # the factor 2 presumably counts one multiply and one add per weight.
    computation_cost.Add(
        self, 'flops',
        tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) * tf.cast(
            symbolic.EvalExpr(symbolic.TENSOR_VALUES,
                              p.input_dims * p.output_dims), tf.int64) * 2)
    use_tpu = py_utils.use_tpu()
    # TensorShape.rank is None when the rank is not statically known. The
    # None check must come first: `None < 26` raises TypeError in Python 3.
    rank = inputs.shape.rank if inputs.shape is not None else None
    if use_tpu and rank is not None and rank < 26:
      # Avoids reshape if feasible and uses Einsum.
      if rank == 2:
        return tf.matmul(inputs, theta.w)
      # Leading dims get letters 'a', 'b', ...; rank < 26 means at most 24
      # of them ('a'..'x'), so they never collide with 'y'/'z' used below.
      leading = ''.join(chr(ord('a') + i) for i in range(rank - 1))
      return tf.einsum('{0}y,yz->{0}z'.format(leading), inputs, theta.w)

    # Generic path: flatten leading dims, matmul, then restore them.
    input_dim = py_utils.GetShape(inputs)[-1]
    act = tf.matmul(tf.reshape(inputs, [-1, input_dim]), theta.w)
    output_dim = tf.shape(theta.w)[-1]
    act = tf.reshape(act,
                     tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0))
    return act
def PlusWXTGrad(theta, state0, inputs, extras, dstate1):
  """Backward step for PlusWXT.

  Returns:
    A tuple (dtheta, dstate0, dinputs, None). `x` is a free variable here
    (presumably a symbolic.Symbol from the enclosing test); its tensor value
    is read from the active symbolic.TENSOR_VALUES map.
  """
  del state0, extras  # The gradient does not depend on them.
  x_value = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
  grad_out = dstate1.value
  grads = (
      py_utils.NestedMap(w=grad_out * x_value * inputs.t),
      py_utils.NestedMap(value=grad_out),
      py_utils.NestedMap(t=grad_out * theta.w * x_value),
      None,
  )
  return grads
def CreateVariable(self, name, var_params, theta_fn=None, *args, **kwargs):
  """Create a variable of this layer according to the parameter `var_params`.

  E.g.::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            'weight', py_utils.WeightParams(shape=[100, 100]))

  `theta_fn` is used to apply a simple transformation on the created
  variable's value before used by the forward computation. E.g., to add the
  global variational noise according to this layer's parameter, one can do::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
          name='weight',
          var_params=py_utils.WeightParams(shape=[100, 100]),
          theta_fn=self.AddGlobalVN)

  Symbolic dimensions in `var_params.shape` are resolved against the active
  symbolic.STATIC_VALUES map before the variable is created. The unresolved
  shape is recorded in `self._var_symbolic_shape_map[name]`.

  Args:
    name: Variable name which is used as the key into vars/theta.
    var_params: `Params` used to create the variable.
    theta_fn: A python function that takes a variable's value and returns a
      new value to be used later for computation. Its signature must be
      (tf.Tensor) -> (tf.Tensor).
    *args: List of args passed to `.py_utils.CreateVariable`.
    **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
  """
  self._CheckName(name)
  if (self.params.skip_lp_regularization and
      py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
    # Build new params with an extended (new) collections list.
    # NOTE(review): only shape/dtype/init/collections are carried over; confirm
    # WeightParams has no other fields that should be preserved here.
    var_params = py_utils.WeightParams(
        shape=var_params.shape,
        dtype=var_params.dtype,
        init=var_params.init,
        collections=(var_params.collections +
                     [py_utils.SKIP_LP_REGULARIZATION]))
  # Remember the shape as given (possibly symbolic) before it is resolved.
  self._var_symbolic_shape_map[name] = var_params.shape
  if (var_params.shape and
      any(symbolic.IsExpr(dim) for dim in var_params.shape)):
    # Variable shapes must be concrete, so resolve them with the STATIC_VALUES
    # map. EvalExpr takes the value type as its first argument.
    # NOTE(review): this assigns into var_params in place, which may be the
    # caller's object; confirm callers do not reuse it expecting the
    # symbolic shape.
    var_params.shape = symbolic.EvalExpr(symbolic.STATIC_VALUES,
                                         var_params.shape)
  value, var = py_utils.CreateVariable(name, var_params, *args, **kwargs)
  self._private_vars[name] = var
  if theta_fn is not None:
    value = theta_fn(value)
  self._private_theta[name] = value
def FProp(self, theta, inputs):
  """Projects the last dimension of `inputs` with the weight `theta.w`.

  Args:
    theta: A NestedMap of weight values for this layer and its children.
    inputs: The inputs tensor, shaped [..., input_dims].

  Returns:
    The projected inputs, as computed by py_utils.ProjectLastDim.
  """
  p = self.params
  with tf.name_scope(p.name):
    # Number of positions = product of all dims except the last.
    num_positions = tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64))
    num_weights = tf.cast(
        symbolic.EvalExpr(symbolic.TENSOR_VALUES,
                          p.input_dims * p.output_dims), tf.int64)
    computation_cost.Add(self, 'flops', num_positions * num_weights * 2)
    return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims,
                                   p.output_dims)
def testEvalExpr(self):
  """Exercises EvalExpr with nested and value-typed symbol-to-value maps."""
  sym_x = symbolic.NewSymbol('x')
  sym_y = symbolic.NewSymbol('y')
  product = sym_x * sym_y
  static = symbolic.STATIC_VALUES
  tensor = symbolic.TENSOR_VALUES

  with symbolic.SymbolToValueMap(static, {sym_x: 2, sym_y: 3}):
    self.assertEqual(symbolic.EvalExpr(static, product), 6)
    # A nested map of the same type shadows the enclosing one...
    with symbolic.SymbolToValueMap(static, {sym_x: 5, sym_y: 6}):
      self.assertEqual(symbolic.EvalExpr(static, product), 30)
    # ...and the enclosing map is visible again once the inner one exits.
    self.assertEqual(symbolic.EvalExpr(static, product), 6)

  # Symbols may also map to Tensors, in which case EvalExpr yields a Tensor.
  ph_a = tf.placeholder(tf.float32)
  ph_b = tf.placeholder(tf.float32)
  with symbolic.SymbolToValueMap(tensor, {sym_x: ph_a, sym_y: ph_b}):
    with symbolic.SymbolToValueMap(static, {sym_x: 2, sym_y: 3}):
      # Maps for different value types are independent of each other.
      self.assertEqual(symbolic.EvalExpr(static, product), 6)
      product_tensor = symbolic.EvalExpr(tensor, product)
  self.assertIsInstance(product_tensor, tf.Tensor)
  with self.session() as sess:
    self.assertEqual(12, sess.run(product_tensor, {ph_a: 3, ph_b: 4}))

  with self.assertRaises(Exception):
    # Every symbol needs a value: partial evaluation is unsupported.
    with symbolic.SymbolToValueMap(static, {sym_y: 3}):
      symbolic.EvalExpr(static, product)
def PlusWXT(theta, state, inputs):
  """One recurrent step: state.value += theta.w * x * inputs.t.

  `x` is a free variable here (presumably a symbolic.Symbol from the
  enclosing test); its tensor value is read from the active
  symbolic.TENSOR_VALUES map. Returns (next_state, empty extras).
  """
  x_value = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
  updated = py_utils.NestedMap(
      value=state.value + theta.w * x_value * inputs.t)
  return updated, py_utils.NestedMap()