Example #1
def _CreateVariableStub(name,
                        params,
                        reuse=None,
                        trainable=True,
                        collections=None,
                        default_seed=None,
                        synchronization=None,
                        aggregation=None):
    """Return a zero tensor of the right shape instead of creating a variable."""
    del reuse
    del default_seed
    del synchronization
    del aggregation
    dtype = params.dtype
    shape = py_utils.ToStaticShape(params.shape)
    # For total samples counters we have to actually create variables so that
    # we can access the 'value' attribute during construction.
    if 'total_samples' in name:
        var = tf.get_variable(name,
                              shape,
                              dtype,
                              tf.constant_initializer(0),
                              collections=collections,
                              trainable=trainable,
                              validate_shape=True)
    else:
        # Share one zero tensor per (graph, shape) instead of creating a new
        # tensor for every identically shaped parameter.
        key = (tf.get_default_graph(), tuple(shape))
        if key in variable_cache:
            var = variable_cache[key]
        else:
            var = tf.zeros(shape, dtype)
            variable_cache[key] = var
    return var, var
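In the uncached branch, the stub keys on the current default graph plus the static shape, so every caller asking for an identically shaped parameter in the same graph receives the same zero tensor. A minimal, self-contained sketch of that caching idea, assuming a TF1-style graph and a simple stand-in for params (the helper names below are illustrative, not part of the snippet above):

import collections
import tensorflow.compat.v1 as tf

FakeParams = collections.namedtuple('FakeParams', ['dtype', 'shape'])
variable_cache = {}

def _zeros_for(params):
    # One zero tensor per (graph, shape), reused across callers.
    key = (tf.get_default_graph(), tuple(params.shape))
    if key not in variable_cache:
        variable_cache[key] = tf.zeros(params.shape, params.dtype)
    return variable_cache[key]

with tf.Graph().as_default():
    a = _zeros_for(FakeParams(tf.float32, [4, 8]))
    b = _zeros_for(FakeParams(tf.float32, [4, 8]))
    assert a is b  # Same shape in the same graph -> one shared zero tensor.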
Example #2
def _CreateVariableStub(name,
                        params,
                        reuse=None,
                        trainable=True,
                        init_wrapper=None,
                        collections=None):
    """Return a zero tensor of the right shape instead of creating a variable."""
    del reuse
    dtype = params.dtype
    shape = py_utils.ToStaticShape(params.shape)
    if init_wrapper:
        # Let the caller-provided wrapper build the value from the dtype and a
        # zero initializer.
        var = init_wrapper(dtype, tf.constant_initializer(0, dtype=dtype))
    # For total samples counters we have to actually create variables so that
    # we can access the 'value' attribute during construction.
    elif 'total_samples' in name:
        var = tf.get_variable(name,
                              shape,
                              dtype,
                              tf.constant_initializer(0, dtype=dtype),
                              collections=collections,
                              trainable=trainable,
                              validate_shape=True)
    else:
        # Share one zero tensor per shape instead of creating a new tensor for
        # every identically shaped parameter.
        key = hash(tuple(shape))
        if key in variable_cache:
            var = variable_cache[key]
        else:
            var = tf.zeros(shape, dtype)
            variable_cache[key] = var
    return var, var
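Example #2 differs in that it accepts an init_wrapper callable, which receives the dtype and a zero constant initializer and builds the value itself, and in that the cache is keyed only by shape. Below is a short, hedged sketch of how such a stub might be swapped in for real variable creation while a graph is built for analysis; the patched hook name (py_utils.CreateVariable) and the use of unittest.mock are assumptions, not something the snippet states:

from unittest import mock

# Hypothetical wiring: substitute the stub for the normal variable factory so
# model construction produces cached zero tensors instead of real variables.
with mock.patch.object(py_utils, 'CreateVariable', _CreateVariableStub):
    ...  # Build the model here; parameter creation goes through the stub.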
Example #3
    @classmethod
    def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim,
                          proj_obj):
        """Linear projection on the last dim of the input tensor along with pruning.

    This is a TPU efficient implementation to avoid reshaping inputs to Rank-2
    tensor by using Einsum for the compute.

    Args:
      inputs: An input Tensor, the last dimension of which is input_dim.
      weight: A weight matrix with shape [input_dim, output_dim].
      input_dim: An integer or a symbolic dim, the last dimension of the inputs.
      output_dim: An integer or a symbolic dim, the last dimension of the
                  outputs.
      proj_obj: a ProjectionLayer object.

    Returns:
      An output Tensor of the same rank as inputs, the last dimension is
      output_dim.
    """
        theta = proj_obj.theta
        p = proj_obj.params
        input_dim = int(
            symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else input_dim)
        output_dim = int(
            symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim) else output_dim)
        if (py_utils.use_tpu() and inputs.shape is not None
                and inputs.shape.rank is not None and inputs.shape.rank < 26):
            # Avoids reshape if feasible and uses Einsum.
            if inputs.shape.rank == 2:
                outputs = tf.matmul(inputs, weight)
            else:
                outputs = cls.GetEinSumResult(inputs, proj_obj)
        else:
            # compression_option 9: optionally compress the inputs through
            # b_matrix before the main matmul.
            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_input']):
                blocked_inputs = tf.reshape(
                    inputs,
                    py_utils.ToStaticShape(
                        [-1, p.pruning_hparams_dict['input_block_size']]))
                compressed_inputs = tf.reshape(
                    py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
                    py_utils.ToStaticShape([
                        -1, input_dim //
                        p.pruning_hparams_dict['input_compression_factor']
                    ]))
            else:
                compressed_inputs = tf.reshape(
                    inputs, py_utils.ToStaticShape([-1, input_dim]))

            # compression_option 10: block-compressed c_matrix, applied either
            # through a mask or by looping over per-block matmuls.
            if p.pruning_hparams_dict['compression_option'] == 10:
                if p.pruning_hparams_dict['block_method'] == 'mask':
                    intermediate_result = py_utils.Matmul(
                        compressed_inputs,
                        tf.multiply(theta.c_matrix_tfvar, theta.c_mask_tfvar))
                elif p.pruning_hparams_dict['block_method'] == 'loop':
                    num_blocks = p.pruning_hparams_dict[
                        'block_compression_factor']
                    input_splitted = tf.split(compressed_inputs,
                                              num_blocks,
                                              axis=-1)
                    output_splitted = []
                    for i, input_i in enumerate(input_splitted):
                        output_splitted.append(
                            py_utils.Matmul(input_i,
                                            theta.c_matrix_tfvar[i, :, :]))
                    intermediate_result = tf.concat(output_splitted, axis=-1)
            else:
                intermediate_result = py_utils.Matmul(compressed_inputs,
                                                      theta.c_matrix_tfvar)

            # compression_option 9: optionally compress the output through
            # d_matrix after the main matmul.
            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_output']):
                blocked_intermediate_result = tf.reshape(
                    intermediate_result,
                    py_utils.ToStaticShape([
                        -1, p.pruning_hparams_dict['output_block_size'] //
                        p.pruning_hparams_dict['output_compression_factor']
                    ]))
                outputs = py_utils.Matmul(blocked_intermediate_result,
                                          theta.d_matrix_tfvar)
            else:
                outputs = intermediate_result

            outputs = tf.reshape(
                outputs,
                tf.concat([
                    tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
                    py_utils.ToStaticShape([output_dim])
                ],
                          axis=0))

        return outputs
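Stripped of the compression branches, the fallback path above reduces to: flatten all leading dimensions, apply one matmul, and restore the leading dimensions with output_dim as the new last dimension. A simplified, self-contained reference version of just that path (illustrative only, not a method of the class above):

import tensorflow.compat.v1 as tf

def project_last_dim_reference(inputs, weight, input_dim, output_dim):
    # [..., input_dim] -> [N, input_dim], where N is the product of the leading dims.
    flat = tf.reshape(inputs, [-1, input_dim])
    flat_out = tf.matmul(flat, weight)  # [N, output_dim]
    # Restore the leading dims of `inputs`, with output_dim as the last dim.
    out_shape = tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0)
    return tf.reshape(flat_out, out_shape)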
Example #4
    @classmethod
    def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim,
                          proj_obj):
        """Linear projection on the last dim of the input tensor along with pruning.

    This is a TPU efficient implementation to avoid reshaping inputs to Rank-2
    tensor by using Einsum for the compute.

    Args:
      inputs: An input Tensor, the last dimension of which is input_dim.
      weight: A weight matrix with shape [input_dim, output_dim].
      input_dim: An integer or a symbolic dim, the last dimension of the inputs.
      output_dim: An integer or a symbolic dim, the last dimension of the
                  outputs.
      proj_obj: a ProjectionLayer object.

    Returns:
      An output Tensor of the same rank as inputs, the last dimension is
      output_dim.
    """
        theta = proj_obj.theta
        p = proj_obj.params
        input_dim = int(
            symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else input_dim)
        output_dim = int(
            symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim) else output_dim)
        if (py_utils.use_tpu() and inputs.shape is not None
                and inputs.shape.rank is not None and inputs.shape.rank < 26):
            # Avoids reshape if feasible and uses Einsum.
            if inputs.shape.rank == 2:
                outputs = tf.matmul(inputs, weight)
            else:
                s = ''.join([chr(x) for x in range(97, 123)])  # abc...xyz
                r = inputs.shape.rank
                outputs = cls.GetEinSumResult(
                    inputs, weight, '{0}y,yz->{0}z'.format(s[:r - 1]),
                    proj_obj)
        else:
            if p.pruning_hparams_dict['compress_input']:
                blocked_inputs = tf.reshape(
                    inputs,
                    py_utils.ToStaticShape(
                        [-1, p.pruning_hparams_dict['input_block_size']]))
                compressed_inputs = tf.reshape(
                    py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
                    py_utils.ToStaticShape([
                        -1, input_dim //
                        p.pruning_hparams_dict['input_compression_factor']
                    ]))
            else:
                compressed_inputs = tf.reshape(
                    inputs, py_utils.ToStaticShape([-1, input_dim]))

            intermediate_result = py_utils.Matmul(compressed_inputs,
                                                  theta.c_matrix_tfvar)

            if p.pruning_hparams_dict['compress_output']:
                blocked_intermediate_result = tf.reshape(
                    intermediate_result,
                    py_utils.ToStaticShape([
                        -1, p.pruning_hparams_dict['output_block_size'] //
                        p.pruning_hparams_dict['output_compression_factor']
                    ]))
                outputs = py_utils.Matmul(blocked_intermediate_result,
                                          theta.d_matrix_tfvar)
            else:
                outputs = intermediate_result

            outputs = tf.reshape(
                outputs,
                tf.concat([
                    tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
                    py_utils.ToStaticShape([output_dim])
                ],
                          axis=0))

        return outputs
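For intuition, here is what the equation string built above expands to for a rank-3 input; the shapes below are made-up values for illustration only:

import string
import tensorflow.compat.v1 as tf

r = 3                                  # inputs.shape.rank
s = string.ascii_lowercase             # same as the chr() loop above: 'abc...xyz'
equation = '{0}y,yz->{0}z'.format(s[:r - 1])  # -> 'aby,yz->abz'
inputs = tf.ones([2, 5, 16])           # [batch, time, input_dim]
weight = tf.ones([16, 32])             # [input_dim, output_dim]
outputs = tf.einsum(equation, inputs, weight)  # shape [2, 5, 32]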