コード例 #1
0
    def test_splice(self):
        ''' Verify splice output shape/values and rejection of bad contexts. '''
        with self.session():
            frames = tf.ones([1, 3, 2], dtype=tf.float32)

            # Sweep every (left, right) pair in [0, 3] x [0, 3], left-major.
            context_pairs = [(left, right)
                             for left in range(4)
                             for right in range(4)]
            for left, right in context_pairs:
                # The test expects the last dimension to be the input's
                # feature size (2) times the window length left + 1 + right.
                expected_dim = 2 * (left + 1 + right)
                spliced = tffeat.splice(frames,
                                        left_context=left,
                                        right_context=right)
                self.assertTupleEqual(spliced.eval().shape,
                                      (1, 3, expected_dim))
                self.assertAllEqual(spliced, tf.ones([1, 3, expected_dim]))

            # A negative context on either side is expected to raise.
            for left, right in ((-2, -2), (2, -2), (-2, 2)):
                with self.assertRaises(ValueError):
                    tffeat.splice(frames,
                                  left_context=left,
                                  right_context=right).eval()
コード例 #2
0
ファイル: common_layers.py プロジェクト: liuweiping2020/delta
def tdnn(x,
         name,
         in_dim,
         context,
         out_dim,
         has_bias=True,
         method='splice_layer'):
    '''
    TDNN (time-delay neural network) layer.

    Args:
      x: input tensor.
        NOTE(review): presumably [batch, time, in_dim], which is what the
        conv1d path requires -- confirm against callers.
      name: variable scope name under which the layer variables are created.
      in_dim: feature dimension of `x`.
      context:
        a int of left and right context, or
        a list of context indexes, e.g. (-2, 0, 2).
        A list is only supported by the 'splice_layer' method.
      out_dim: output feature dimension.
      has_bias: whether a bias term is added.
      method:
        splice_layer: use column-first patch-based copy.
        splice_op: use row-first while_loop copy.
        conv1d: use conv1d as TDNN equivalence.

    Returns:
      The transformed tensor.

    Raises:
      ValueError: if `method` is not one of the supported methods, or if a
        context list is combined with 'splice_op' or 'conv1d'.
    '''
    # Reject unknown methods before any variable scope is opened.
    if method not in ('splice_layer', 'splice_op', 'conv1d'):
        raise ValueError('Unsupported method: %s.' % (method))

    if hasattr(context, '__iter__'):
        if method in ('splice_op', 'conv1d'):
            msg = 'Method splice_op and conv1d does not support context list.'
            raise ValueError(msg)
        context_list = context
        context_size = len(context)
    else:
        context_list = range(-context, context + 1)
        context_size = context * 2 + 1

    with tf.variable_scope(name):
        if method == 'conv1d':
            # The kernel has to span the whole window (left context + current
            # frame + right context) to match the splice-based methods. A
            # width of `context` would look at the wrong frames and would be
            # empty for context == 0.
            # NOTE: this changes the 'DW' variable shape compared with the
            # previous [context, in_dim, out_dim] kernel, so checkpoints
            # saved with the old shape will not load.
            kernel = tf.get_variable(
                name='DW',
                shape=[context_size, in_dim, out_dim],
                dtype=tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())
            x = tf.nn.conv1d(x, kernel, stride=1, padding='SAME')
            if has_bias:
                b = tf.get_variable(name='bias',
                                    shape=[out_dim],
                                    dtype=tf.float32,
                                    initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, b)
            return x

        # Both splice variants stack the context frames along the feature
        # axis and then share the same affine projection.
        if method == 'splice_layer':
            x = splice_layer(x, 'splice', context_list)
        else:
            x = speech_ops.splice(x, context, context)
        return linear(x,
                      'linear', [in_dim * context_size, out_dim],
                      has_bias=has_bias)
コード例 #3
0
ファイル: speech_ops_test.py プロジェクト: zhankm/delta
  def test_splice(self):
    ''' Verify splice output shape/values and rejection of bad contexts. '''
    # Context pairs with at least one negative side; each is expected to fail
    # when the op is evaluated.
    invalid_contexts = [(-2, -2), (2, -2), (-2, 2)]

    with self.cached_session(use_gpu=False, force_gpu=False):
      ones_feat = tf.ones([1, 3, 2], dtype=tf.float32)

      for left in range(4):
        for right in range(4):
          # Expected last dimension: feature size (2) times the window
          # length left + 1 + right.
          out_dim = 2 * (left + 1 + right)
          spliced = tffeat.splice(
              ones_feat, left_context=left, right_context=right)
          self.assertTupleEqual(spliced.eval().shape, (1, 3, out_dim))
          self.assertAllEqual(spliced, tf.ones([1, 3, out_dim]))

      for left, right in invalid_contexts:
        with self.assertRaises(tf.errors.InvalidArgumentError):
          tffeat.splice(
              ones_feat, left_context=left, right_context=right).eval()