def FProp(self, theta, inputs, paddings):
    """Apply convolution to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor, expected to be of shape [batch, time].

    Returns:
      outputs, out_paddings pair.
    """
    p = self.params
    with tf.name_scope(p.name):
      inputs = py_utils.with_dependencies([
          py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
          py_utils.assert_shape_match(
              tf.shape(inputs),
              tf.concat([
                  tf.shape(paddings),
                  [-1, symbolic.ToStatic(self.input_channels)]
              ], 0))
      ], inputs)

      def _ApplyPadding(tensor_in, padding_in):
        padding_expanded = tf.expand_dims(tf.expand_dims(padding_in, -1), -1)
        return tensor_in * (1.0 - padding_expanded)

      # Zeroing out padded inputs.
      inputs = _ApplyPadding(inputs, paddings)

      # Apply conv on 'inputs'.
      out = self._ApplyConv(theta, inputs)

      # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1,
      # but it is unlikely to cause real problems in practice. Passing the
      # dilation rate to the pooling op raises an error: pooling with SAME
      # padding is not implemented for dilation_rate > 1.
      # NOTE: we use window=p.filter_stride[0] to stay compatible with the
      # legacy implementation. Consider updating it to use the actual filter
      # shape.
      conv_padding = ComputeConvOutputPadding(
          paddings, window=p.filter_stride[0], stride=p.filter_stride[0])
      # Assuming padded nodes will be properly zeroed out, if necessary, by
      # subsequent layers.
      # out = _ApplyPadding(out, conv_padding)
      out = py_utils.HasShape(
          out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
      return out, conv_padding
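
# A minimal sketch (plain TensorFlow, hypothetical helper name; not the Lingvo
# implementation) of how the output paddings above can be derived: a strided
# max-pool over the time axis with SAME padding marks an output step as padded
# whenever any input step contributing to it was padded, mirroring the
# window=stride call in the FProp above.
def _conv_output_padding_sketch(paddings, window, stride):
  """paddings: [batch, time] float tensor with 1.0 at padded positions."""
  if stride == 1:
    return paddings
  pooled = tf.nn.pool(
      tf.expand_dims(paddings, -1),  # [batch, time, 1]
      window_shape=[window],
      pooling_type='MAX',
      strides=[stride],
      padding='SAME')
  return tf.squeeze(pooled, -1)  # [batch, ceil(time / stride)]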
  def testEvalExpr(self):
    x = symbolic.Symbol('x')
    y = symbolic.Symbol('y')
    xy = x * y

    # Without symbol-to-value map.
    self.assertEqual(xy, symbolic.ToStatic(xy))
    self.assertEqual(xy, symbolic.ToTensor(xy))

    with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {x: 2, y: 3}):
      self.assertEqual(symbolic.ToStatic(xy), 6)
      # The inner map overrides the outer map.
      with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {x: 5, y: 6}):
        self.assertEqual(symbolic.ToStatic(xy), 30)
      # Back to the outer map.
      self.assertEqual(symbolic.ToStatic(xy), 6)

    # EvalExpr can also evaluate a symbolic expression to a
    # Tensor.
    a = tf.placeholder(tf.float32)
    b = tf.placeholder(tf.float32)
    with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, {x: a, y: b}):
      with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {x: 2, y: 3}):
        # Value maps of different types do not affect each other.
        self.assertEqual(symbolic.ToStatic(xy), 6)
        ab = symbolic.ToTensor(xy)
        self.assertIsInstance(ab, tf.Tensor)
        with self.session() as sess:
          self.assertEqual(12, sess.run(ab, {a: 3, b: 4}))

      # EvalExpr supports partial evaluation.
      with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {y: 3}):
        x3 = symbolic.ToStatic(xy)
        with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {x: 9}):
          self.assertEqual(27, symbolic.ToStatic(x3))
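
# A rough sketch (using sympy directly, hypothetical helper names) of the
# substitution mechanics exercised above, assuming the symbols behave like
# sympy expressions: symbol-to-value maps form a stack, evaluation substitutes
# whatever bindings are currently visible, and unbound symbols simply survive
# as a partially evaluated expression. The real module additionally keeps
# separate stacks for STATIC_VALUES and TENSOR_VALUES.
import sympy

_value_maps = []  # innermost map last

def _to_static_sketch(expr):
  if not isinstance(expr, sympy.Expr):
    return expr  # plain ints/floats pass through unchanged
  merged = {}
  for m in _value_maps:  # later (inner) maps override earlier ones
    merged.update(m)
  value = expr.subs(merged)
  return int(value) if value.is_Integer else value

sx, sy = sympy.symbols('x y')
_value_maps.append({sy: 3})
partial = _to_static_sketch(sx * sy)  # partially evaluated: 3*x
_value_maps.append({sx: 9})
assert _to_static_sketch(partial) == 27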
    def testToFromProto(self):
        outer = hyperparams.Params()
        outer.Define('integer_val', 1, '')
        outer.Define('cls_type', type(int), '')
        inner = hyperparams.Params()
        inner.Define('float_val', 2.71, '')
        inner.Define('string_val', 'rosalie et adrien', '')
        inner.Define('bool_val', True, '')
        inner.Define('list_of_tuples_of_dicts', [({'string_key': 1729},)], '')
        inner.Define('range', range(1, 3), '')
        outer.Define('inner', inner, '')
        outer.Define('empty_list', [], '')
        outer.Define('empty_tuple', (), '')
        outer.Define('empty_dict', {}, '')
        outer.Define('enum', TestEnum.B, '')
        outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')
        outer.Define('dataclass', TestDataClass(a=[42], b=tf.float32), '')
        outer.Define('namedtuple',
                     tf.io.FixedLenSequenceFeature([42], tf.float32), '')
        outer.Define('symbol_x', symbolic.Symbol('x'), '')
        outer.Define('symbol_2x', outer.symbol_x * 2, '')

        rebuilt_outer = hyperparams.InstantiableParams.FromProto(
            outer.ToProto())

        self.assertNotIn('cls', rebuilt_outer)
        self.assertEqual(outer.integer_val, rebuilt_outer.integer_val)
        self.assertEqual(outer.cls_type, rebuilt_outer.cls_type)
        self.assertNear(outer.inner.float_val, rebuilt_outer.inner.float_val,
                        1e-6)
        self.assertEqual(outer.inner.string_val,
                         rebuilt_outer.inner.string_val)
        self.assertEqual(outer.inner.bool_val, rebuilt_outer.inner.bool_val)
        self.assertEqual(outer.inner.list_of_tuples_of_dicts,
                         rebuilt_outer.inner.list_of_tuples_of_dicts)
        self.assertEqual([1, 2], rebuilt_outer.inner.range)  # Rebuilt as list.
        self.assertEqual(outer.empty_list, rebuilt_outer.empty_list)
        self.assertEqual(outer.empty_tuple, rebuilt_outer.empty_tuple)
        self.assertEqual(outer.empty_dict, rebuilt_outer.empty_dict)
        self.assertEqual(outer.enum, rebuilt_outer.enum)
        self.assertEqual(outer.proto, rebuilt_outer.proto)
        self.assertEqual(outer.dataclass, rebuilt_outer.dataclass)
        self.assertEqual(outer.namedtuple, rebuilt_outer.namedtuple)

        with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                       {rebuilt_outer.symbol_x: 42}):
            self.assertEqual(symbolic.ToStatic(rebuilt_outer.symbol_2x), 84)
    def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim,
                          proj_obj):
        """Linear projection on the last dim of the input tensor along with pruning.

        This is a TPU-efficient implementation that avoids reshaping the inputs
        to a rank-2 tensor by using Einsum for the compute.

        Args:
          inputs: An input Tensor, the last dimension of which is input_dim.
          weight: A weight matrix with shape [input_dim, output_dim].
          input_dim: An integer or a symbolic dim, the last dimension of the
            inputs.
          output_dim: An integer or a symbolic dim, the last dimension of the
            outputs.
          proj_obj: A ProjectionLayer object.

        Returns:
          An output Tensor of the same rank as inputs; the last dimension is
          output_dim.
        """
        theta = proj_obj.theta
        p = proj_obj.params
        input_dim = int(
            symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim)
            else input_dim)
        output_dim = int(
            symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim)
            else output_dim)
        if (py_utils.use_tpu() and inputs.shape is not None
                and inputs.shape.rank is not None and inputs.shape.rank < 26):
            # Avoids reshape if feasible and uses Einsum.
            if inputs.shape.rank == 2:
                outputs = tf.matmul(inputs, weight)
            else:
                outputs = cls.GetEinSumResult(inputs, proj_obj)
        else:
            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_input']):
                blocked_inputs = tf.reshape(
                    inputs,
                    py_utils.ToStaticShape(
                        [-1, p.pruning_hparams_dict['input_block_size']]))
                compressed_inputs = tf.reshape(
                    py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
                    py_utils.ToStaticShape([
                        -1, input_dim //
                        p.pruning_hparams_dict['input_compression_factor']
                    ]))
            else:
                compressed_inputs = tf.reshape(
                    inputs, py_utils.ToStaticShape([-1, input_dim]))

            if p.pruning_hparams_dict['compression_option'] == 10:
                if p.pruning_hparams_dict['block_method'] == 'mask':
                    intermediate_result = py_utils.Matmul(
                        compressed_inputs,
                        tf.multiply(theta.c_matrix_tfvar, theta.c_mask_tfvar))
                elif p.pruning_hparams_dict['block_method'] == 'loop':
                    num_blocks = p.pruning_hparams_dict[
                        'block_compression_factor']
                    input_splitted = tf.split(compressed_inputs,
                                              num_blocks,
                                              axis=-1)
                    output_splitted = []
                    for i, input_i in enumerate(input_splitted):
                        output_splitted.append(
                            py_utils.Matmul(input_i,
                                            theta.c_matrix_tfvar[i, :, :]))
                    intermediate_result = tf.concat(output_splitted, axis=-1)
            else:
                intermediate_result = py_utils.Matmul(compressed_inputs,
                                                      theta.c_matrix_tfvar)

            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_output']):
                blocked_intermediate_result = tf.reshape(
                    intermediate_result,
                    py_utils.ToStaticShape([
                        -1, p.pruning_hparams_dict['output_block_size'] //
                        p.pruning_hparams_dict['output_compression_factor']
                    ]))
                outputs = py_utils.Matmul(blocked_intermediate_result,
                                          theta.d_matrix_tfvar)
            else:
                outputs = intermediate_result

            outputs = tf.reshape(
                outputs,
                tf.concat([
                    tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
                    py_utils.ToStaticShape([output_dim])
                ],
                          axis=0))

        return outputs
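
# A minimal sketch (hypothetical shapes, plain TensorFlow) of the Einsum trick
# the docstring refers to: projecting the last dimension of a rank-3 tensor
# with a single einsum gives the same result as flattening to rank-2, doing a
# matmul, and reshaping back, but avoids the reshapes, which can be costly on
# TPU.
x = tf.random.normal([4, 7, 16])   # [batch, time, input_dim]
w = tf.random.normal([16, 32])     # [input_dim, output_dim]
via_einsum = tf.einsum('bti,io->bto', x, w)
via_reshape = tf.reshape(tf.matmul(tf.reshape(x, [-1, 16]), w), [4, 7, 32])
# Both paths yield the same [batch, time, output_dim] result up to float error.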
    def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim,
                          proj_obj):
        """Linear projection on the last dim of the input tensor along with pruning.

        This is a TPU-efficient implementation that avoids reshaping the inputs
        to a rank-2 tensor by using Einsum for the compute.

        Args:
          inputs: An input Tensor, the last dimension of which is input_dim.
          weight: A weight matrix with shape [input_dim, output_dim].
          input_dim: An integer or a symbolic dim, the last dimension of the
            inputs.
          output_dim: An integer or a symbolic dim, the last dimension of the
            outputs.
          proj_obj: A ProjectionLayer object.

        Returns:
          An output Tensor of the same rank as inputs; the last dimension is
          output_dim.
        """
        theta = proj_obj.theta
        p = proj_obj.params
        input_dim = int(
            symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim)
            else input_dim)
        output_dim = int(
            symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim)
            else output_dim)
        if (py_utils.use_tpu() and inputs.shape is not None
                and inputs.shape.rank is not None and inputs.shape.rank < 26):
            # Avoids reshape if feasible and uses Einsum.
            if inputs.shape.rank == 2:
                outputs = tf.matmul(inputs, weight)
            else:
                s = ''.join([chr(x) for x in range(97, 123)])  # abc...xyz
                r = inputs.shape.rank
                outputs = cls.GetEinSumResult(
                    inputs, weight, '{0}y,yz->{0}z'.format(s[:r - 1]),
                    proj_obj)
        else:
            if p.pruning_hparams_dict['compress_input']:
                blocked_inputs = tf.reshape(
                    inputs,
                    py_utils.ToStaticShape(
                        [-1, p.pruning_hparams_dict['input_block_size']]))
                compressed_inputs = tf.reshape(
                    py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
                    py_utils.ToStaticShape([
                        -1, input_dim //
                        p.pruning_hparams_dict['input_compression_factor']
                    ]))
            else:
                compressed_inputs = tf.reshape(
                    inputs, py_utils.ToStaticShape([-1, input_dim]))

            intermediate_result = py_utils.Matmul(compressed_inputs,
                                                  theta.c_matrix_tfvar)

            if p.pruning_hparams_dict['compress_output']:
                blocked_intermediate_result = tf.reshape(
                    intermediate_result,
                    py_utils.ToStaticShape([
                        -1, p.pruning_hparams_dict['output_block_size'] //
                        p.pruning_hparams_dict['output_compression_factor']
                    ]))
                outputs = py_utils.Matmul(blocked_intermediate_result,
                                          theta.d_matrix_tfvar)
            else:
                outputs = intermediate_result

            outputs = tf.reshape(
                outputs,
                tf.concat([
                    tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
                    py_utils.ToStaticShape([output_dim])
                ],
                          axis=0))

        return outputs
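
# A shape-only sketch (hypothetical sizes, plain TensorFlow) of the
# compress_input path above: the inputs are cut into blocks of
# input_block_size, each block is projected down by the compression factor
# (b_matrix's shape is implied by the reshape targets), and the blocks are
# re-flattened so the compressed feature dimension becomes
# input_dim // input_compression_factor.
input_dim, block_size, factor = 512, 64, 4
b_matrix = tf.random.normal([block_size, block_size // factor])  # stands in for theta.b_matrix_tfvar
x = tf.random.normal([8, input_dim])       # [batch, input_dim]
blocked = tf.reshape(x, [-1, block_size])  # [batch * input_dim // block_size, block_size]
compressed = tf.reshape(
    tf.matmul(blocked, b_matrix),
    [-1, input_dim // factor])             # back to [batch, input_dim // factor]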