Example #1
def Elman(theta, state0, inputs):
    h0, w, b, x = state0.h, theta.w, theta.b, inputs.x
    xw = py_utils.Matmul(tf.concat([x, h0], axis=1), w)  # 1st part: affine transform.
    # 2nd part: sigmoid nonlinearity; padded steps carry the previous state forward.
    padding = inputs.get('padding', None)
    h1 = _ApplyPadding(padding, v_no_pad=tf.sigmoid(xw + b), v_pad=state0.h)

    state1 = py_utils.NestedMap(h=h1)
    if padding is not None:
        state1.padding = inputs.padding

    return (state1, py_utils.NestedMap(h=h1))
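For context, here is a minimal, self-contained sketch (not part of the original example) of how an Elman step like the one above could be unrolled over a time-major sequence with a plain Python loop. The names elman_step and unroll_elman are hypothetical; the real code packages state and inputs in py_utils.NestedMap and drives the cell through Lingvo's recurrent machinery rather than a Python loop.

import tensorflow as tf


def elman_step(w, b, x, h0, padding=None):
    """One Elman step in plain TF, mirroring the cell above (no py_utils)."""
    xw = tf.matmul(tf.concat([x, h0], axis=1), w)  # 1st part: affine transform.
    h1 = tf.sigmoid(xw + b)                        # 2nd part: nonlinearity.
    if padding is not None:
        # Where padding == 1, carry the previous state through unchanged.
        h1 = padding * h0 + (1.0 - padding) * h1
    return h1


def unroll_elman(w, b, xs, h0, paddings=None):
    """Unrolls the step over a time-major sequence xs of shape [T, B, D]."""
    h, hs = h0, []
    for t in range(xs.shape[0]):
        pad = None if paddings is None else paddings[t]
        h = elman_step(w, b, xs[t], h, pad)
        hs.append(h)
    return tf.stack(hs), h


# Example shapes: T=5 steps, batch B=2, input D=3, hidden H=4; w is [D + H, H].
w = tf.random.normal([7, 4])
b = tf.zeros([4])
xs = tf.random.normal([5, 2, 3])
h0 = tf.zeros([2, 4])
outputs, final_h = unroll_elman(w, b, xs, h0)  # outputs: [5, 2, 4]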
Example #2
    def Grad(h0, w, b, x, padding, h1, dh1):
        del b
        dh1_orig = dh1
        dh1 = _ApplyPadding(padding, dh1, tf.zeros_like(dh1, dtype=dh1.dtype))

        # We hand-roll the gradient for the 2nd half of the cell as a demo.
        # h1 = tf.sigmoid(xw + b)
        # 𝛔'(x) = ((1 - 𝛔(x)) * 𝛔(x))
        dxwb = (dh1 * (1 - h1) * h1)
        dxw, db = dxwb, tf.reduce_sum(dxwb, axis=0)

        # Uses tf.gradients for the 1st half of the cell as a demo.
        xw = py_utils.Matmul(tf.concat([x, h0], axis=1), w)
        dh0, dx, dw = tf.gradients(ys=[xw], xs=[h0, x, w], grad_ys=[dxw])

        dh0 = _ApplyPadding(padding, dh0, dh1_orig)

        return dh0, dx, dw, db
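The hand-rolled piece above relies on the identity 𝛔'(z) = 𝛔(z) * (1 - 𝛔(z)). Below is a small, self-contained check (not part of the original code, written with TF2's GradientTape rather than tf.gradients) that this matches autodiff:

import tensorflow as tf

# d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z)); compare autodiff against the
# hand-rolled form used for dxwb above (upstream gradient dh1 taken as ones).
z = tf.random.normal([4, 3])
with tf.GradientTape() as tape:
    tape.watch(z)
    h = tf.sigmoid(z)
    loss = tf.reduce_sum(h)
dz_auto = tape.gradient(loss, z)   # autodiff gradient
dz_hand = 1.0 * (1 - h) * h        # dh1 * (1 - h1) * h1 with dh1 == 1
tf.debugging.assert_near(dz_auto, dz_hand, atol=1e-6)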
    def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim,
                          proj_obj):
        """Linear projection on the last dim of the input tensor along with pruning.

    This is a TPU efficient implementation to avoid reshaping inputs to Rank-2
    tensor by using Einsum for the compute.

    Args:
      inputs: An input Tensor, the last dimension of which is input_dim.
      weight: A weight matrix with shape [input_dim, output_dim].
      input_dim: An integer or a symbolic dim, the last dimension of the inputs.
      output_dim: An integer or a symbolic dim, the last dimension of the
                  outputs.
      proj_obj: a ProjectionLayer object.

    Returns:
      An output Tensor of the same rank as inputs, the last dimension is
      output_dim.
    """
        theta = proj_obj.theta
        p = proj_obj.params
        input_dim = int(
            symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim)
            else input_dim)
        output_dim = int(
            symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim)
            else output_dim)
        if (py_utils.use_tpu() and inputs.shape is not None
                and inputs.shape.rank is not None and inputs.shape.rank < 26):
            # Avoids reshape if feasible and uses Einsum.
            if inputs.shape.rank == 2:
                outputs = tf.matmul(inputs, weight)
            else:
                outputs = cls.GetEinSumResult(inputs, proj_obj)
        else:
            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_input']):
                blocked_inputs = tf.reshape(
                    inputs,
                    py_utils.ToStaticShape(
                        [-1, p.pruning_hparams_dict['input_block_size']]))
                compressed_inputs = tf.reshape(
                    py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
                    py_utils.ToStaticShape([
                        -1, input_dim //
                        p.pruning_hparams_dict['input_compression_factor']
                    ]))
            else:
                compressed_inputs = tf.reshape(
                    inputs, py_utils.ToStaticShape([-1, input_dim]))

            if p.pruning_hparams_dict['compression_option'] == 10:
                if p.pruning_hparams_dict['block_method'] == 'mask':
                    intermediate_result = py_utils.Matmul(
                        compressed_inputs,
                        tf.multiply(theta.c_matrix_tfvar, theta.c_mask_tfvar))
                elif p.pruning_hparams_dict['block_method'] == 'loop':
                    num_blocks = p.pruning_hparams_dict[
                        'block_compression_factor']
                    input_splitted = tf.split(compressed_inputs,
                                              num_blocks,
                                              axis=-1)
                    output_splitted = []
                    for i, input_i in enumerate(input_splitted):
                        output_splitted.append(
                            py_utils.Matmul(input_i,
                                            theta.c_matrix_tfvar[i, :, :]))
                    intermediate_result = tf.concat(output_splitted, axis=-1)
            else:
                intermediate_result = py_utils.Matmul(compressed_inputs,
                                                      theta.c_matrix_tfvar)

            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_output']):
                blocked_intermediate_result = tf.reshape(
                    intermediate_result,
                    py_utils.ToStaticShape([
                        -1, p.pruning_hparams_dict['output_block_size'] //
                        p.pruning_hparams_dict['output_compression_factor']
                    ]))
                outputs = py_utils.Matmul(blocked_intermediate_result,
                                          theta.d_matrix_tfvar)
            else:
                outputs = intermediate_result

            outputs = tf.reshape(
                outputs,
                tf.concat([
                    tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
                    py_utils.ToStaticShape([output_dim])
                ],
                          axis=0))

        return outputs
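The Einsum branch above delegates to cls.GetEinSumResult; for the core idea in isolation, here is a hypothetical, self-contained sketch (plain TF, no pruning and no proj_obj) of projecting the last dimension without flattening to rank 2, using the same '{0}y,yz->{0}z' equation pattern that the variant below builds explicitly:

import tensorflow as tf


def project_last_dim(inputs, weight):
    """Hypothetical sketch: project the last dim of `inputs` ([..., input_dim])
    with `weight` ([input_dim, output_dim]) without reshaping to rank 2."""
    rank = inputs.shape.rank
    if rank == 2:
        return tf.matmul(inputs, weight)
    # One lowercase letter per leading dimension, e.g. 'aby,yz->abz' for rank 3.
    # Assumes the rank is small enough that the letters never reach 'y'/'z'.
    letters = ''.join(chr(c) for c in range(ord('a'), ord('a') + rank - 1))
    return tf.einsum('{0}y,yz->{0}z'.format(letters), inputs, weight)


# e.g. a [batch, time, input_dim] tensor projected to [batch, time, output_dim]:
x = tf.random.normal([2, 7, 16])
w = tf.random.normal([16, 32])
y = project_last_dim(x, w)  # shape [2, 7, 32]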
    def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim,
                          proj_obj):
        """Linear projection on the last dim of the input tensor along with pruning.

    This is a TPU efficient implementation to avoid reshaping inputs to Rank-2
    tensor by using Einsum for the compute.

    Args:
      inputs: An input Tensor, the last dimension of which is input_dim.
      weight: A weight matrix with shape [input_dim, output_dim].
      input_dim: An integer or a symbolic dim, the last dimension of the inputs.
      output_dim: An integer or a symbolic dim, the last dimension of the
                  outputs.
      proj_obj: a ProjectionLayer object.

    Returns:
      An output Tensor of the same rank as inputs, the last dimension is
      output_dim.
    """
        theta = proj_obj.theta
        p = proj_obj.params
        input_dim = int(
            symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim)
            else input_dim)
        output_dim = int(
            symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim)
            else output_dim)
        if (py_utils.use_tpu() and inputs.shape is not None
                and inputs.shape.rank is not None and inputs.shape.rank < 26):
            # Avoids reshape if feasible and uses Einsum.
            if inputs.shape.rank == 2:
                outputs = tf.matmul(inputs, weight)
            else:
                s = ''.join([chr(x) for x in range(97, 123)])  # abc...xyz
                r = inputs.shape.rank
                outputs = cls.GetEinSumResult(
                    inputs, weight, '{0}y,yz->{0}z'.format(s[:r - 1]),
                    proj_obj)
        else:
            if p.pruning_hparams_dict['compress_input']:
                blocked_inputs = tf.reshape(
                    inputs,
                    py_utils.ToStaticShape(
                        [-1, p.pruning_hparams_dict['input_block_size']]))
                compressed_inputs = tf.reshape(
                    py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
                    py_utils.ToStaticShape([
                        -1, input_dim //
                        p.pruning_hparams_dict['input_compression_factor']
                    ]))
            else:
                compressed_inputs = tf.reshape(
                    inputs, py_utils.ToStaticShape([-1, input_dim]))

            intermediate_result = py_utils.Matmul(compressed_inputs,
                                                  theta.c_matrix_tfvar)

            if p.pruning_hparams_dict['compress_output']:
                blocked_intermediate_result = tf.reshape(
                    intermediate_result,
                    py_utils.ToStaticShape([
                        -1, p.pruning_hparams_dict['output_block_size'] //
                        p.pruning_hparams_dict['output_compression_factor']
                    ]))
                outputs = py_utils.Matmul(blocked_intermediate_result,
                                          theta.d_matrix_tfvar)
            else:
                outputs = intermediate_result

            outputs = tf.reshape(
                outputs,
                tf.concat([
                    tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
                    py_utils.ToStaticShape([output_dim])
                ],
                          axis=0))

        return outputs
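To make the compress_input branch concrete, here is a hypothetical, self-contained sketch (plain TF, no ProjectionLayer and no pruning_hparams_dict) of the block-wise input compression: the flattened inputs are cut into blocks of input_block_size, each block is projected by a shared compression matrix (here of shape [input_block_size, input_block_size // input_compression_factor] so the shapes line up, standing in for theta.b_matrix_tfvar), and the result is reshaped back to [-1, input_dim // input_compression_factor]:

import tensorflow as tf


def compress_input_blocks(inputs, b_matrix, input_dim, input_block_size,
                          input_compression_factor):
    """Hypothetical sketch of the compress_input path above."""
    # [..., input_dim] -> [num_blocks_total, input_block_size]
    blocked = tf.reshape(inputs, [-1, input_block_size])
    # Project every block with the shared compression matrix.
    compressed_blocks = tf.matmul(blocked, b_matrix)
    # Stitch blocks back together: [-1, input_dim // input_compression_factor].
    return tf.reshape(compressed_blocks,
                      [-1, input_dim // input_compression_factor])


# e.g. input_dim=16 split into blocks of 8, each block compressed 2x:
x = tf.random.normal([4, 16])
b = tf.random.normal([8, 4])               # [block_size, block_size // factor]
y = compress_input_blocks(x, b, 16, 8, 2)  # shape [4, 8]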