def _CreateVariableStub(name,
                        params,
                        reuse=None,
                        trainable=True,
                        collections=None,
                        default_seed=None,
                        synchronization=None,
                        aggregation=None):
  """Return a zero tensor of the right shape instead of creating variable.

  Zero tensors are memoized in the enclosing-scope `variable_cache`, keyed by
  the default graph, the static shape and the dtype, so identical requests
  within one graph share a single tensor.

  Args:
    name: Variable name. If it contains 'total_samples' a real variable is
      created with `tf.get_variable` instead of a stub tensor.
    params: Object providing the `dtype` and `shape` of the requested variable.
    reuse: Unused.
    trainable: Forwarded to `tf.get_variable` when a real variable is created.
    collections: Forwarded to `tf.get_variable` when a real variable is
      created.
    default_seed: Unused.
    synchronization: Unused.
    aggregation: Unused.

  Returns:
    A `(var, var)` tuple whose two elements are the same object: either the
    real variable or the (possibly cached) zero tensor.
  """
  # These arguments are not consulted by the stub.
  del reuse
  del default_seed
  del synchronization
  del aggregation
  dtype = params.dtype
  shape = py_utils.ToStaticShape(params.shape)
  # For total samples counters we have to actually create variables so that
  # we can access the 'value' attribute during construction.
  if 'total_samples' in name:
    var = tf.get_variable(
        name,
        shape,
        dtype,
        tf.constant_initializer(0),
        collections=collections,
        trainable=trainable,
        validate_shape=True)
  else:
    # dtype must be part of the key: otherwise a request for the same shape
    # with a different dtype would be served the cached tensor of the first
    # dtype that happened to be created.
    key = (tf.get_default_graph(), tuple(shape), dtype)
    if key not in variable_cache:
      variable_cache[key] = tf.zeros(shape, dtype)
    var = variable_cache[key]
  return var, var
def _CreateVariableStub(name,
                        params,
                        reuse=None,
                        trainable=True,
                        init_wrapper=None,
                        collections=None):
  """Return a zero tensor of the right shape instead of creating variable.

  Zero tensors are memoized in the enclosing-scope `variable_cache`, keyed by
  the default graph, the static shape and the dtype, so identical requests
  within one graph share a single tensor.

  Args:
    name: Variable name. If it contains 'total_samples' (and no `init_wrapper`
      is given) a real variable is created with `tf.get_variable`.
    params: Object providing the `dtype` and `shape` of the requested variable.
    reuse: Unused.
    trainable: Forwarded to `tf.get_variable` when a real variable is created.
    init_wrapper: Optional callable. When truthy, the stub is whatever
      `init_wrapper(dtype, zero_initializer)` returns, and neither a variable
      nor a cached tensor is created.
    collections: Forwarded to `tf.get_variable` when a real variable is
      created.

  Returns:
    A `(var, var)` tuple whose two elements are the same object.
  """
  del reuse  # Not consulted by the stub.
  dtype = params.dtype
  shape = py_utils.ToStaticShape(params.shape)
  if init_wrapper:
    var = init_wrapper(dtype, tf.constant_initializer(0, dtype=dtype))
  elif 'total_samples' in name:
    # For total samples counters we have to actually create variables so that
    # we can access the 'value' attribute during construction.
    var = tf.get_variable(
        name,
        shape,
        dtype,
        tf.constant_initializer(0, dtype=dtype),
        collections=collections,
        trainable=trainable,
        validate_shape=True)
  else:
    # Key on the actual values rather than on hash(shape): a bare hash lets
    # two different shapes collide, and omitting the graph / dtype lets a
    # tensor from another graph or of another dtype be handed back.
    key = (tf.get_default_graph(), tuple(shape), dtype)
    if key not in variable_cache:
      variable_cache[key] = tf.zeros(shape, dtype)
    var = variable_cache[key]
  return var, var
def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim, proj_obj):
  """Linear projection on the last dim of the input tensor along with pruning.

  This is a TPU efficient implementation to avoid reshaping inputs to Rank-2
  tensor by using Einsum for the compute.

  When the TPU path is not taken, inputs are flattened to rank 2 and multiplied
  by the matrices held in `proj_obj.theta` (`b_matrix_tfvar`, `c_matrix_tfvar`,
  `c_mask_tfvar`, `d_matrix_tfvar`) as selected by
  `proj_obj.params.pruning_hparams_dict`; `weight` is only used by the rank-2
  TPU path.

  Args:
    inputs: An input Tensor, the last dimension of which is input_dim.
    weight: A weight matrix with shape [input_dim, output_dim].
    input_dim: An integer or a symbolic dim, the last dimension of the inputs.
    output_dim: An integer or a symbolic dim, the last dimension of the
      outputs.
    proj_obj: a ProjectionLayer object.

  Returns:
    An output Tensor of the same rank as inputs, the last dimension is
    output_dim.

  Raises:
    ValueError: If `compression_option` is 10 and `block_method` is neither
      'mask' nor 'loop'.
  """
  theta = proj_obj.theta
  p = proj_obj.params
  input_dim = int(
      symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else input_dim)
  output_dim = int(
      symbolic.ToStatic(output_dim)
      if symbolic.IsExpr(output_dim) else output_dim)
  if (py_utils.use_tpu() and inputs.shape is not None and
      inputs.shape.rank is not None and inputs.shape.rank < 26):
    # Avoids reshape if feasible and uses Einsum.
    if inputs.shape.rank == 2:
      outputs = tf.matmul(inputs, weight)
    else:
      outputs = cls.GetEinSumResult(inputs, proj_obj)
  else:
    hparams = p.pruning_hparams_dict
    compression_option = hparams['compression_option']

    # Stage 1: optional block-wise input compression (option 9 only). The
    # 'compress_input' key is only read when the option is 9.
    if compression_option == 9 and hparams['compress_input']:
      blocked_inputs = tf.reshape(
          inputs, py_utils.ToStaticShape([-1, hparams['input_block_size']]))
      compressed_inputs = tf.reshape(
          py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
          py_utils.ToStaticShape(
              [-1, input_dim // hparams['input_compression_factor']]))
    else:
      compressed_inputs = tf.reshape(
          inputs, py_utils.ToStaticShape([-1, input_dim]))

    # Stage 2: main projection through c_matrix_tfvar.
    if compression_option == 10:
      block_method = hparams['block_method']
      if block_method == 'mask':
        intermediate_result = py_utils.Matmul(
            compressed_inputs,
            tf.multiply(theta.c_matrix_tfvar, theta.c_mask_tfvar))
      elif block_method == 'loop':
        # Split the last axis into blocks and multiply block i by slice i of
        # c_matrix_tfvar, then stitch the per-block results back together.
        num_blocks = hparams['block_compression_factor']
        input_splitted = tf.split(compressed_inputs, num_blocks, axis=-1)
        output_splitted = [
            py_utils.Matmul(input_i, theta.c_matrix_tfvar[i, :, :])
            for i, input_i in enumerate(input_splitted)
        ]
        intermediate_result = tf.concat(output_splitted, axis=-1)
      else:
        # Without this, intermediate_result would be left unbound and the
        # failure would surface below as an UnboundLocalError.
        raise ValueError(
            "Unsupported block_method {!r}; expected 'mask' or 'loop'.".format(
                block_method))
    else:
      intermediate_result = py_utils.Matmul(compressed_inputs,
                                            theta.c_matrix_tfvar)

    # Stage 3: optional block-wise output decompression (option 9 only).
    if compression_option == 9 and hparams['compress_output']:
      blocked_intermediate_result = tf.reshape(
          intermediate_result,
          py_utils.ToStaticShape([
              -1, hparams['output_block_size'] //
              hparams['output_compression_factor']
          ]))
      outputs = py_utils.Matmul(blocked_intermediate_result,
                                theta.d_matrix_tfvar)
    else:
      outputs = intermediate_result

    # Restore the leading dimensions of inputs, with output_dim last.
    outputs = tf.reshape(
        outputs,
        tf.concat([
            tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
            py_utils.ToStaticShape([output_dim])
        ],
                  axis=0))
  return outputs
def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim, proj_obj):
  """Projects the last dim of `inputs` to output_dim, with pruning support.

  On TPU, for inputs whose static rank is known and below 26, the projection
  is computed without reshaping: a plain matmul for rank-2 inputs, otherwise
  `cls.GetEinSumResult` with an explicit einsum equation. In every other case
  the inputs are flattened to rank 2, multiplied by the matrices stored in
  `proj_obj.theta` according to `proj_obj.params.pruning_hparams_dict`, and
  reshaped back.

  Args:
    inputs: An input Tensor, the last dimension of which is input_dim.
    weight: A weight matrix with shape [input_dim, output_dim]. Only used on
      the TPU path.
    input_dim: An integer or a symbolic dim, the last dimension of the inputs.
    output_dim: An integer or a symbolic dim, the last dimension of the
      outputs.
    proj_obj: a ProjectionLayer object.

  Returns:
    An output Tensor of the same rank as inputs, the last dimension is
    output_dim.
  """
  theta = proj_obj.theta
  p = proj_obj.params

  # Resolve symbolic dims to plain Python ints.
  if symbolic.IsExpr(input_dim):
    input_dim = symbolic.ToStatic(input_dim)
  input_dim = int(input_dim)
  if symbolic.IsExpr(output_dim):
    output_dim = symbolic.ToStatic(output_dim)
  output_dim = int(output_dim)

  can_skip_reshape = (
      py_utils.use_tpu() and inputs.shape is not None and
      inputs.shape.rank is not None and inputs.shape.rank < 26)
  if can_skip_reshape:
    # Avoids reshape if feasible and uses Einsum.
    rank = inputs.shape.rank
    if rank == 2:
      return tf.matmul(inputs, weight)
    # One letter per leading dim; 'y' and 'z' name the contracted and the
    # output dim in the equation below.
    letters = ''.join(map(chr, range(97, 123)))  # abc...xyz
    equation = '{0}y,yz->{0}z'.format(letters[:rank - 1])
    return cls.GetEinSumResult(inputs, weight, equation, proj_obj)

  hparams = p.pruning_hparams_dict

  # Flatten to rank 2, optionally compressing the input block-wise first.
  if hparams['compress_input']:
    input_blocks = tf.reshape(
        inputs, py_utils.ToStaticShape([-1, hparams['input_block_size']]))
    projected_blocks = py_utils.Matmul(input_blocks, theta.b_matrix_tfvar)
    compressed_width = input_dim // hparams['input_compression_factor']
    flat_inputs = tf.reshape(projected_blocks,
                             py_utils.ToStaticShape([-1, compressed_width]))
  else:
    flat_inputs = tf.reshape(inputs,
                             py_utils.ToStaticShape([-1, input_dim]))

  projected = py_utils.Matmul(flat_inputs, theta.c_matrix_tfvar)

  # Optionally expand the result block-wise through d_matrix_tfvar.
  if hparams['compress_output']:
    block_width = (
        hparams['output_block_size'] // hparams['output_compression_factor'])
    output_blocks = tf.reshape(projected,
                               py_utils.ToStaticShape([-1, block_width]))
    projected = py_utils.Matmul(output_blocks, theta.d_matrix_tfvar)

  # Restore the leading dimensions of inputs, with output_dim last.
  leading_dims = tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32)
  target_shape = tf.concat(
      [leading_dims, py_utils.ToStaticShape([output_dim])], axis=0)
  return tf.reshape(projected, target_shape)