def Elman(theta, state0, inputs):
  h0, w, b, x = state0.h, theta.w, theta.b, inputs.x
  xw = py_utils.Matmul(tf.concat([x, h0], axis=1), w)  # 1st part
  # 2nd part
  padding = inputs.get('padding', None)
  h1 = _ApplyPadding(padding, v_no_pad=tf.sigmoid(xw + b), v_pad=state0.h)
  state1 = py_utils.NestedMap(h=h1)
  if padding is not None:
    state1.padding = inputs.padding
  return (state1, py_utils.NestedMap(h=h1))
def Grad(h0, w, b, x, padding, h1, dh1):
  del b
  dh1_orig = dh1
  dh1 = _ApplyPadding(padding, dh1, tf.zeros_like(dh1, dtype=dh1.dtype))
  # We hand-roll the gradient for the 2nd half of the cell as a demo.
  # h1 = tf.sigmoid(xw + b)
  # 𝛔'(x) = ((1 - 𝛔(x)) * 𝛔(x))
  dxwb = (dh1 * (1 - h1) * h1)
  dxw, db = dxwb, tf.reduce_sum(dxwb, axis=0)
  # Uses tf.gradients for the 1st half of the cell as a demo.
  xw = py_utils.Matmul(tf.concat([x, h0], axis=1), w)
  dh0, dx, dw = tf.gradients(ys=[xw], xs=[h0, x, w], grad_ys=[dxw])
  dh0 = _ApplyPadding(padding, dh0, dh1_orig)
  return dh0, dx, dw, db
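# A quick, self-contained check (a sketch with made-up shapes, using plain
# TensorFlow 2 eager mode rather than the graph-mode tf.gradients above) that
# the hand-rolled sigmoid gradient, dxwb = dh1 * (1 - h1) * h1, matches
# autodiff. Nothing here is part of the cell code; it only illustrates the
# identity 𝛔'(x) = (1 - 𝛔(x)) * 𝛔(x) used in the comment.
import tensorflow as tf

xwb = tf.random.normal([3, 4])   # stands in for xw + b
dh1 = tf.random.normal([3, 4])   # stands in for the incoming gradient

with tf.GradientTape() as tape:
  tape.watch(xwb)
  h1 = tf.sigmoid(xwb)
dxwb_auto = tape.gradient(h1, xwb, output_gradients=dh1)
dxwb_manual = dh1 * (1 - h1) * h1

print(bool(tf.reduce_all(tf.abs(dxwb_auto - dxwb_manual) < 1e-5)))  # True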
def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim, proj_obj):
  """Linear projection on the last dim of the input tensor along with pruning.

  This is a TPU-efficient implementation that avoids reshaping the inputs to a
  rank-2 tensor by using Einsum for the compute.

  Args:
    inputs: An input Tensor, the last dimension of which is input_dim.
    weight: A weight matrix with shape [input_dim, output_dim].
    input_dim: An integer or a symbolic dim, the last dimension of the inputs.
    output_dim: An integer or a symbolic dim, the last dimension of the
      outputs.
    proj_obj: A ProjectionLayer object.

  Returns:
    An output Tensor of the same rank as inputs; the last dimension is
    output_dim.
  """
  theta = proj_obj.theta
  p = proj_obj.params
  input_dim = int(
      symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else
      input_dim)
  output_dim = int(
      symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim) else
      output_dim)
  if (py_utils.use_tpu() and inputs.shape is not None and
      inputs.shape.rank is not None and inputs.shape.rank < 26):
    # Avoids reshape if feasible and uses Einsum.
    if inputs.shape.rank == 2:
      outputs = tf.matmul(inputs, weight)
    else:
      outputs = cls.GetEinSumResult(inputs, proj_obj)
  else:
    if (p.pruning_hparams_dict['compression_option'] == 9 and
        p.pruning_hparams_dict['compress_input']):
      blocked_inputs = tf.reshape(
          inputs,
          py_utils.ToStaticShape(
              [-1, p.pruning_hparams_dict['input_block_size']]))
      compressed_inputs = tf.reshape(
          py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
          py_utils.ToStaticShape([
              -1, input_dim //
              p.pruning_hparams_dict['input_compression_factor']
          ]))
    else:
      compressed_inputs = tf.reshape(
          inputs, py_utils.ToStaticShape([-1, input_dim]))
    if p.pruning_hparams_dict['compression_option'] == 10:
      if p.pruning_hparams_dict['block_method'] == 'mask':
        intermediate_result = py_utils.Matmul(
            compressed_inputs,
            tf.multiply(theta.c_matrix_tfvar, theta.c_mask_tfvar))
      elif p.pruning_hparams_dict['block_method'] == 'loop':
        num_blocks = p.pruning_hparams_dict['block_compression_factor']
        input_splitted = tf.split(compressed_inputs, num_blocks, axis=-1)
        output_splitted = []
        for i, input_i in enumerate(input_splitted):
          output_splitted.append(
              py_utils.Matmul(input_i, theta.c_matrix_tfvar[i, :, :]))
        intermediate_result = tf.concat(output_splitted, axis=-1)
    else:
      intermediate_result = py_utils.Matmul(compressed_inputs,
                                            theta.c_matrix_tfvar)
    if (p.pruning_hparams_dict['compression_option'] == 9 and
        p.pruning_hparams_dict['compress_output']):
      blocked_intermediate_result = tf.reshape(
          intermediate_result,
          py_utils.ToStaticShape([
              -1, p.pruning_hparams_dict['output_block_size'] //
              p.pruning_hparams_dict['output_compression_factor']
          ]))
      outputs = py_utils.Matmul(blocked_intermediate_result,
                                theta.d_matrix_tfvar)
    else:
      outputs = intermediate_result
    outputs = tf.reshape(
        outputs,
        tf.concat([
            tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
            py_utils.ToStaticShape([output_dim])
        ], axis=0))
  return outputs
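# A minimal sketch (plain TensorFlow with made-up shapes, not the lingvo
# py_utils helpers) of the compress_input bookkeeping in the pruning branch
# above: the input's last dimension is cut into blocks of input_block_size,
# each block is projected by b_matrix_tfvar, and the result is reshaped back
# so the last dimension shrinks by input_compression_factor.
import tensorflow as tf

batch, input_dim = 4, 32
input_block_size = 8          # each block of 8 features ...
input_compression_factor = 2  # ... is compressed 2x, down to 4

inputs = tf.random.normal([batch, input_dim])
b_matrix = tf.random.normal(
    [input_block_size, input_block_size // input_compression_factor])

blocked_inputs = tf.reshape(inputs, [-1, input_block_size])       # [16, 8]
compressed_inputs = tf.reshape(
    tf.matmul(blocked_inputs, b_matrix),
    [-1, input_dim // input_compression_factor])                  # [4, 16]
print(compressed_inputs.shape)  # (4, 16)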
def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim, proj_obj):
  """Linear projection on the last dim of the input tensor along with pruning.

  This is a TPU-efficient implementation that avoids reshaping the inputs to a
  rank-2 tensor by using Einsum for the compute.

  Args:
    inputs: An input Tensor, the last dimension of which is input_dim.
    weight: A weight matrix with shape [input_dim, output_dim].
    input_dim: An integer or a symbolic dim, the last dimension of the inputs.
    output_dim: An integer or a symbolic dim, the last dimension of the
      outputs.
    proj_obj: A ProjectionLayer object.

  Returns:
    An output Tensor of the same rank as inputs; the last dimension is
    output_dim.
  """
  theta = proj_obj.theta
  p = proj_obj.params
  input_dim = int(
      symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else
      input_dim)
  output_dim = int(
      symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim) else
      output_dim)
  if (py_utils.use_tpu() and inputs.shape is not None and
      inputs.shape.rank is not None and inputs.shape.rank < 26):
    # Avoids reshape if feasible and uses Einsum.
    if inputs.shape.rank == 2:
      outputs = tf.matmul(inputs, weight)
    else:
      s = ''.join([chr(x) for x in range(97, 123)])  # abc...xyz
      r = inputs.shape.rank
      outputs = cls.GetEinSumResult(
          inputs, weight, '{0}y,yz->{0}z'.format(s[:r - 1]), proj_obj)
  else:
    if p.pruning_hparams_dict['compress_input']:
      blocked_inputs = tf.reshape(
          inputs,
          py_utils.ToStaticShape(
              [-1, p.pruning_hparams_dict['input_block_size']]))
      compressed_inputs = tf.reshape(
          py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
          py_utils.ToStaticShape([
              -1, input_dim //
              p.pruning_hparams_dict['input_compression_factor']
          ]))
    else:
      compressed_inputs = tf.reshape(
          inputs, py_utils.ToStaticShape([-1, input_dim]))
    intermediate_result = py_utils.Matmul(compressed_inputs,
                                          theta.c_matrix_tfvar)
    if p.pruning_hparams_dict['compress_output']:
      blocked_intermediate_result = tf.reshape(
          intermediate_result,
          py_utils.ToStaticShape([
              -1, p.pruning_hparams_dict['output_block_size'] //
              p.pruning_hparams_dict['output_compression_factor']
          ]))
      outputs = py_utils.Matmul(blocked_intermediate_result,
                                theta.d_matrix_tfvar)
    else:
      outputs = intermediate_result
    outputs = tf.reshape(
        outputs,
        tf.concat([
            tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
            py_utils.ToStaticShape([output_dim])
        ], axis=0))
  return outputs
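# A minimal sketch (made-up shapes) of the Einsum path above: projecting only
# the last dimension of a higher-rank input without flattening it to rank 2.
# The equation string is built the same way as in the code, with one letter
# per leading dimension; tf.einsum stands in for cls.GetEinSumResult here.
import tensorflow as tf

inputs = tf.random.normal([2, 5, 16])   # e.g. [batch, time, input_dim]
weight = tf.random.normal([16, 8])      # [input_dim, output_dim]

s = ''.join(chr(x) for x in range(97, 123))    # 'abc...xyz'
r = inputs.shape.rank                          # 3
equation = '{0}y,yz->{0}z'.format(s[:r - 1])   # 'aby,yz->abz'

outputs = tf.einsum(equation, inputs, weight)
print(outputs.shape)  # (2, 5, 8)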