def test_splice(self):
    ''' Check splice output for valid contexts and errors for negative ones. '''
    # (left_context, right_context) pairs the test expects to be rejected.
    invalid_contexts = ((-2, -2), (2, -2), (-2, 2))

    with self.session():
        frames = tf.ones([1, 3, 2], dtype=tf.float32)

        for left in range(4):
            for right in range(4):
                # The expected last dimension is the input feature size (2)
                # multiplied by left + 1 + right.
                window = left + 1 + right
                spliced = tffeat.splice(
                    frames, left_context=left, right_context=right)
                self.assertTupleEqual(spliced.eval().shape,
                                      (1, 3, 2 * window))
                self.assertAllEqual(spliced, tf.ones([1, 3, 2 * window]))

        for left, right in invalid_contexts:
            with self.assertRaises(ValueError):
                tffeat.splice(
                    frames, left_context=left, right_context=right).eval()
def tdnn(x,
         name,
         in_dim,
         context,
         out_dim,
         has_bias=True,
         method='splice_layer'):
    '''
    TDNN implementation.

    Args:
      x: input tensor. The conv1d branch feeds it to tf.nn.conv1d, so it is
        presumably [batch, time, in_dim] -- TODO confirm for the splice
        branches.
      name: variable scope name under which the layer variables are created.
      in_dim: feature dimension of `x`.
      context:
        a int of left and right context, or
        a list of context indexes, e.g. (-2, 0, 2).
      out_dim: output feature dimension.
      has_bias: whether to add a bias term.
      method:
        splice_layer: use column-first patch-based copy.
        splice_op: use row-first while_loop copy.
        conv1d: use conv1d as TDNN equivalence.

    Returns:
      The transformed tensor.

    Raises:
      ValueError: if `method` is unknown, or if a context list is combined
        with `splice_op` or `conv1d`.
    '''
    # Fail fast, before any variable scope is opened.
    if method not in ('splice_layer', 'splice_op', 'conv1d'):
        raise ValueError('Unsupported method: %s.' % (method))

    if hasattr(context, '__iter__'):
        context_size = len(context)
        if method in ('splice_op', 'conv1d'):
            msg = 'Method splice_op and conv1d does not support context list.'
            raise ValueError(msg)
        context_list = context
    else:
        # Symmetric context: `context` frames on each side plus the current one.
        context_size = context * 2 + 1
        context_list = range(-context, context + 1)

    with tf.variable_scope(name):
        if method == 'conv1d':
            # The kernel must span the whole context window (2 * context + 1
            # frames) to match what the splice + linear branches see. The
            # previous width of `context` covered only one side and was empty
            # for context == 0.
            # NOTE: this changes the shape of `DW`; checkpoints trained with
            # the old conv1d kernel width will not load.
            kernel = tf.get_variable(
                name='DW',
                shape=[context_size, in_dim, out_dim],
                dtype=tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())
            x = tf.nn.conv1d(x, kernel, stride=1, padding='SAME')
            if has_bias:
                b = tf.get_variable(
                    name='bias',
                    shape=[out_dim],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, b)
            return x

        if method == 'splice_layer':
            x = splice_layer(x, 'splice', context_list)
        else:  # method == 'splice_op'
            x = speech_ops.splice(x, context, context)
        # Both splice variants feed the stacked frames through one linear map.
        x = linear(
            x, 'linear', [in_dim * context_size, out_dim], has_bias=has_bias)
        return x
def test_splice(self):
    ''' Check splice output for valid contexts and errors for negative ones. '''
    # (left_context, right_context) pairs the test expects to be rejected.
    invalid_contexts = ((-2, -2), (2, -2), (-2, 2))

    # The session is requested with both GPU flags switched off.
    with self.cached_session(use_gpu=False, force_gpu=False):
        frames = tf.ones([1, 3, 2], dtype=tf.float32)

        for left in range(4):
            for right in range(4):
                # The expected last dimension is the input feature size (2)
                # multiplied by left + 1 + right.
                window = left + 1 + right
                spliced = tffeat.splice(
                    frames, left_context=left, right_context=right)
                self.assertTupleEqual(spliced.eval().shape,
                                      (1, 3, 2 * window))
                self.assertAllEqual(spliced, tf.ones([1, 3, 2 * window]))

        for left, right in invalid_contexts:
            with self.assertRaises(tf.errors.InvalidArgumentError):
                tffeat.splice(
                    frames, left_context=left, right_context=right).eval()