def _to_tensor_array(name, v, clear_after_read=None):
    '''
    Build a TensorArray holding `batch` elements and fill it by unstacking `v`.

    NOTE(review): `batch` is not a parameter of this function; it is looked
    up in the surrounding scope and must be defined there. Confirm against
    the enclosing code.

    Args:
      name: name given to the TensorArray
      v: tensor to unstack; its dtype is used as the dtype of the array
      clear_after_read: forwarded unchanged to tf.TensorArray

    Returns:
      the TensorArray produced by `unstack(v)`
    '''
    return tf.TensorArray(
        dtype=v.dtype,
        size=batch,
        name=name,
        clear_after_read=clear_after_read,
    ).unstack(v)
def splice_layer(x, name, context):
    '''
    Splice a tensor along the last dimension with context.

    For every offset in `context` the input is shifted by that many steps
    along axis 1. Positions that would fall outside the sequence are filled
    by repeating the first frame (negative offsets) or the last frame
    (zero / positive offsets). The shifted copies are then laid side by side
    along the last axis, in the order given by `context`.

    e.g.:
      t = [[[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]]]
      splice_layer(t, name, [0, 1]) =
          [[[1, 2, 3, 4, 5, 6],
            [4, 5, 6, 7, 8, 9],
            [7, 8, 9, 7, 8, 9]]]

    Args:
      x: a tf.Tensor with shape (B, T, D) a.k.a. (N, H, W)
      name: name of the variable scope the ops are created in
      context: a non-empty list of integer context offsets

    Returns:
      spliced tensor with shape (B, T, D * len(context))

    Raises:
      ValueError: if `context` is empty.
    '''
    context_len = len(context)
    if context_len == 0:
        # A zero-size TensorArray cannot be stacked into a meaningful result;
        # fail early with a clear message instead.
        raise ValueError('context must contain at least one offset')

    with tf.variable_scope(name):
        input_shape = tf.shape(x)
        B, T = input_shape[0], input_shape[1]
        array = tf.TensorArray(x.dtype, size=context_len)
        for idx, offset in enumerate(context):
            # Never pad with more than T frames: without the clamp an offset
            # whose magnitude exceeds T yields a shifted copy that does not
            # have T frames, which breaks the stack / reshape below.
            pad = tf.minimum(T, abs(offset))
            if offset < 0:
                # Shift right: drop the trailing `pad` frames and repeat the
                # first frame in front.
                sliced = x[:, 0:T - pad, :]
                tiled = tf.tile(x[:, 0:1, :], [1, pad, 1])
                final = tf.concat((tiled, sliced), axis=1)
            else:
                # Shift left: drop the leading `pad` frames and repeat the
                # last frame at the end.
                sliced = x[:, pad:T, :]
                tiled = tf.tile(x[:, -1:, :], [1, pad, 1])
                final = tf.concat((sliced, tiled), axis=1)
            array = array.write(idx, final)
        # (len(context), B, T, D) -> (B, T, len(context), D) -> (B, T, -1)
        spliced = array.stack()
        spliced = tf.transpose(spliced, (1, 2, 0, 3))
        spliced = tf.reshape(spliced, (B, T, -1))
    return spliced
def _new_tensor_array(name, size, dtype=None):
    '''
    Create an empty TensorArray with room for `size` elements.

    Args:
      name: name given to the TensorArray
      size: number of elements the array can store
      dtype: element dtype, forwarded unchanged to tf.TensorArray.
        NOTE(review): the default is None; presumably callers always pass a
        real dtype. Confirm.

    Returns:
      the newly created tf.TensorArray
    '''
    empty_array = tf.TensorArray(dtype=dtype, size=size, name=name)
    return empty_array