Example #1
0
def cnn_attention(x, channels, scope='attention'):
    """Build a convolutional attention block on top of a feature map.

    Three 1x1 ``conv_sa`` projections of ``x`` are created. The ``f`` and ``h``
    projections are max-pooled with a 2x2 window, the ``g`` projection keeps
    the input resolution. ``softmax(g @ f^T)`` weights ``h``; the result is
    projected to ``channels`` features and scaled by ``gamma``, a variable
    initialised to zero.

    Parameters
    ----------
    x
        Input feature map. Axes 1, 2 and 3 of its shape are used below, so it
        is presumably laid out as [batch, height, width, depth] -- confirm
        against the callers.
    channels: int
        Channel count of the output projection. The ``f``/``g`` projections
        use ``channels // 8`` features and the ``h`` projection
        ``channels // 2``.
    scope: str
        Name scope wrapping the ops of this block.

    Returns
    -------
    x_att
        ``gamma * out + x``, the scaled attention output added onto the input.
    x_gma
        ``gamma * out``, the scaled attention output alone.
    """
    with tf.name_scope(scope):
        in_shape = shape_list(x)
        pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")

        # Every projection is a 1x1 convolution with unit stride.
        proj_kwargs = {"kernel": (1, 1), "strides": (1, 1)}

        feat_f = conv_sa(x, channels // 8, scope="f_attn_conv", **proj_kwargs)
        feat_f = pool(feat_f)

        feat_g = conv_sa(x, channels // 8, scope='g_attn_conv', **proj_kwargs)

        feat_h = conv_sa(x, channels // 2, scope='h_attn_conv', **proj_kwargs)
        feat_h = pool(feat_h)

        f_dims = shape_list(feat_f)
        g_dims = shape_list(feat_g)

        # Collapse the two spatial axes into one sequence axis.
        seq_f = tf.reshape(feat_f, shape=[-1, f_dims[1] * f_dims[2], f_dims[3]])
        seq_g = tf.reshape(feat_g, shape=[-1, g_dims[1] * g_dims[2], g_dims[3]])

        # Rows index positions of g, columns index the pooled positions of f.
        scores = tf.linalg.matmul(seq_g, seq_f, transpose_b=True)
        attn = tf.keras.activations.softmax(scores)

        h_dims = shape_list(feat_h)
        seq_h = tf.reshape(feat_h, shape=[-1, h_dims[1] * h_dims[2], h_dims[3]])
        attended = tf.linalg.matmul(attn, seq_h)

        # NOTE(review): the depth here is taken from the input (in_shape[3] // 2)
        # while h carries channels // 2 features, so this presumably expects
        # x to have `channels` features -- confirm.
        out = tf.reshape(attended, shape=[-1, in_shape[1], in_shape[2], in_shape[3] // 2])
        out = conv_sa(out, channels, scope='out_attn_conv', **proj_kwargs)

        # gamma starts at zero, so x_att equals x until gamma is updated.
        gamma = tf.compat.v1.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
        x_att = gamma * out + x
        x_gma = gamma * out

        return x_att, x_gma
Example #2
0
    def call(self, inp, encoder_input_emb, chord_change_pred):
        """Decode chord logits from the input and the encoder embedding.

        Parameters
        ----------
        inp
            Forwarded to ``self.encode_segment_frequency``.
        encoder_input_emb
            Added onto the decoder input and used as key and value of the
            second attention layer of every block.
        chord_change_pred
            Forwarded to ``chord_block_compression`` together with the
            segment encodings.

        Returns
        -------
        logits
            Output of ``self.out_dense``.
        chord_pred
            Index of the largest logit along the last axis, as int32.
        """
        seg_enc = self.encode_segment_frequency(inp)

        # Compress the encodings into blocks, then expand them back again.
        blocked, block_ids = chord_block_compression(seg_enc, chord_change_pred)
        blocked = chord_block_decompression(blocked, block_ids)
        blocked.set_shape([None, self.n_steps, self.dec_input_emb_size])

        dec_in = seg_enc + blocked + encoder_input_emb
        dec_in = dec_in + positional_encoding(
            batch_size=shape_list(dec_in)[0],
            timesteps=self.n_steps,
            n_units=self.dec_input_emb_size)

        hidden = self.dropout(dec_in)

        # Softmax over zeros: every block contributes with the same weight.
        mix = tf.nn.softmax(tf.zeros(self.num_attn_blocks))
        accumulated = tf.zeros(shape=shape_list(seg_enc))

        blocks = zip(self.attn_layers_1, self.attn_layers_2, self.ff_layers)
        for block_no, (self_attn, cross_attn, ffn) in enumerate(blocks):
            hidden = self_attn(q=hidden, k=hidden, v=hidden)
            hidden = cross_attn(q=hidden,
                                k=encoder_input_emb,
                                v=encoder_input_emb)
            hidden = ffn(hidden)
            accumulated = accumulated + mix[block_no] * hidden

        logits = self.out_dense(accumulated)
        return logits, tf.argmax(input=logits, axis=-1, output_type=tf.int32)
Example #3
0
def pad_to_multiple_2d(x, block_shape):
    """Zero-pad the height and width of x up to multiples of block_shape.

    Padding is appended at the end of each of the two axes only.

    Parameters
    ----------
    x
        A [batch, heads, h, w, depth] or [batch, h, w, depth] tensor
    block_shape
        A 2D list of integer shapes; ``block_shape[0]`` applies to the height
        and ``block_shape[1]`` to the width.

    Returns
    -------
    padded_x
        A [batch, heads, h, w, depth] or [batch, h, w, depth] tensor

    Raises
    ------
    ValueError
        If x is neither rank 4 nor rank 5.
    """
    old_shape = x.get_shape().dims
    rank = len(old_shape)
    if rank not in (4, 5):
        # Previously this fell through and failed later on an unbound `paddings`.
        raise ValueError(
            f"pad_to_multiple_2d expects a rank 4 or rank 5 tensor, got rank {rank}")
    last = old_shape[-1]

    # Height and width are the two axes right before the depth axis:
    # (1, 2) for rank 4 and (2, 3) for rank 5.
    height_axis = rank - 3
    width_axis = rank - 2

    x_shape = shape_list(x)
    # `-size % block` is the distance from size up to the next multiple of block.
    height_padding = -x_shape[height_axis] % block_shape[0]
    width_padding = -x_shape[width_axis] % block_shape[1]

    paddings = [[0, 0] for _ in range(rank)]
    paddings[height_axis] = [0, height_padding]
    paddings[width_axis] = [0, width_padding]

    padded_x = tf.pad(x, paddings)
    # Re-attach the depth dimension recorded before padding.
    padded_shape = padded_x.get_shape().as_list()
    padded_shape = padded_shape[:-1] + [last]
    padded_x.set_shape(padded_shape)
    return padded_x
Example #4
0
    def call(self, q, k, v):  # pylint: disable=W0221
        """Run multi-head attention followed by a residual and layer norm.

        Parameters
        ----------
        q
            Query input; also added onto the output as the residual.
        k
            Key input.
        v
            Value input.

        All three are split along axis 2 after their dense projection, so
        they are presumably [batch, steps, units] tensors -- confirm against
        the callers.

        Returns
        -------
        outputs
            ``self.layer_norm`` applied to the projected attention output
            plus ``q``.
        """
        query = self.q_emb_dense(q)
        key = self.k_emb_dense(k)
        value = self.v_emb_dense(v)

        # Split the feature axis into heads and stack the heads on the batch axis.
        q_heads = tf.concat(tf.split(query, self.n_heads, 2), 0)
        k_heads = tf.concat(tf.split(key, self.n_heads, 2), 0)
        v_heads = tf.concat(tf.split(value, self.n_heads, 2), 0)

        scores = tf.matmul(q_heads, tf.transpose(k_heads, perm=[0, 2, 1]))
        if self.relative_position:
            key_steps, key_units = shape_list(k_heads)[1:]
            rel_k = relative_positional_encoding(
                n_steps=key_steps, n_units=key_units, max_dist=self.max_dist)
            # Swap the two leading axes around the matmul and back afterwards.
            rel_k = tf.matmul(tf.transpose(a=q_heads, perm=[1, 0, 2]),
                              rel_k,
                              transpose_b=True)
            rel_k = tf.transpose(a=rel_k, perm=[1, 0, 2])
            scores += rel_k

        # Scale by the square root of the per-head key width.
        scores = scores / shape_list(k_heads)[-1]**0.5
        if self.causal:
            ones = tf.ones_like(scores[0])
            lower_tri = tf.linalg.LinearOperatorLowerTriangular(ones).to_dense()
            fill = tf.ones_like(lower_tri) * (-2**32 + 1)

            def _mask_upper(weights):
                # Entries above the diagonal get a large negative score.
                return tf.where(lower_tri == 0, fill, weights)

            scores = tf.map_fn(_mask_upper, scores)

        if self.self_mask:
            # NOTE(review): the diagonal scores are set to 0 (not to a large
            # negative value) before the softmax -- confirm this is intended.
            zero_diag = tf.zeros_like(scores[:, :, 0])
            scores = tf.linalg.set_diag(input=scores, diagonal=zero_diag)

        attn = tf.nn.softmax(scores)
        attn = self.dropout(attn)

        outputs = tf.matmul(attn, v_heads)
        if self.relative_position:
            value_steps, value_units = shape_list(v_heads)[1:]
            rel_v = relative_positional_encoding(
                n_steps=value_steps, n_units=value_units, max_dist=self.max_dist)
            rel_v = tf.matmul(tf.transpose(a=attn, perm=[1, 0, 2]), rel_v)
            rel_v = tf.transpose(a=rel_v, perm=[1, 0, 2])
            outputs += rel_v

        # Undo the head stacking: heads go from the batch axis back to axis 2.
        outputs = tf.concat(tf.split(outputs, self.n_heads, 0), 2)
        outputs = self.out_dense(outputs)
        outputs += q  # Residual connection
        return self.layer_norm(outputs)
Example #5
0
def scatter_blocks_2d(x, indices, shape):
    """Scatter the blocks of x to the positions in indices and reshape.

    Parameters
    ----------
    x
        Blocked tensor. Its first two and its last axis are kept, everything
        in between is flattened into one axis.
    indices
        Target position of every flattened element; flattened to one index
        per row before scattering.
    shape
        Shape of the returned tensor.

    Returns
    -------
    y
        The scattered values reshaped to ``shape``.
    """
    dims = shape_list(x)
    flat = tf.reshape(x, [dims[0], dims[1], -1, dims[-1]])
    # Move the flattened axis to the front: [length, batch, heads, dim]
    length_major = tf.transpose(flat, [2, 0, 1, 3])
    target_dims = shape_list(length_major)

    row_indices = tf.reshape(indices, [-1, 1])
    scattered = tf.scatter_nd(row_indices, length_major, target_dims)

    # Put the length axis back behind the two leading axes.
    restored = tf.transpose(scattered, [1, 2, 0, 3])
    return tf.reshape(restored, shape)
Example #6
0
def conv_sa(
    x,
    channels,
    kernel=(4, 4),
    strides=(2, 2),
    pad=0,
    pad_type="zero",
    spectral_norm=True,
    scope="conv_0"
):
    """Apply an optionally padded 2-D convolution to x.

    Parameters
    ----------
    x
        Input tensor; axes 1 and 2 are padded, so it is presumably
        [batch, height, width, depth] -- confirm against the callers.
    channels: int
        Number of output filters.
    kernel: tuple
        Kernel size of the convolution.
    strides: tuple
        Strides of the convolution.
    pad: int
        Requested padding. Nothing is padded unless this is positive; the
        amount actually applied is derived from the input height below.
    pad_type: str
        ``'zero'`` for constant padding or ``'reflect'`` for reflection
        padding. Any other value leaves x unpadded.
    spectral_norm: bool
        Use ``ConvSN2D`` when true, a plain ``Conv2D`` with ``'same'``
        padding otherwise.
    scope: str
        Name scope, also used as prefix of the layer name.

    Returns
    -------
    y
        Output of the convolution layer.
    """
    with tf.name_scope(scope):
        if pad > 0:
            height = shape_list(x)[1]
            if height % strides[1] != 0:
                pad = max(kernel[1] - (height % strides[1]), 0)
            else:
                pad *= 2

            # The same (before, after) split is used for both padded axes.
            before = pad // 2
            after = pad - before
            paddings = [[0, 0], [before, after], [before, after], [0, 0]]

            if pad_type == 'zero':
                x = tf.pad(x, paddings)
            elif pad_type == 'reflect':
                x = tf.pad(x, paddings, mode='REFLECT')

        if not spectral_norm:
            plain_conv = tf.keras.layers.Conv2D(
                channels, kernel_size=kernel, strides=strides, padding='same', name=f"{scope}_conv"
            )
            return plain_conv(x)

        sn_conv = ConvSN2D(channels, kernel_size=kernel, strides=strides, scope=f"{scope}_SN")
        return sn_conv(x)
Example #7
0
    def call(self, inp):
        """Encode every segment with attention along its width axis.

        Parameters
        ----------
        inp
            Input that can be reshaped to
            [-1, self.freq_size, self.segment_width].

        Returns
        -------
        outputs
            ``self.layer_norm`` applied to the dense projection of the
            encoded segments.
        """
        # -> [batch_size*n_steps, freq_size, segment_width]
        segments = tf.reshape(inp, shape=[-1, self.freq_size, self.segment_width])

        # -> [batch_size*n_steps, segment_width, freq_size]
        sequence = tf.transpose(a=segments, perm=[0, 2, 1])
        pos_enc = positional_encoding(
            batch_size=shape_list(sequence)[0],
            timesteps=self.segment_width,
            n_units=self.freq_size)
        # The positional encoding is scaled down and shifted before being added.
        sequence = sequence + (pos_enc * 0.01 + 0.01)

        attended = self.attn_layer(q=sequence, k=sequence, v=sequence)
        encoded = self.feed_forward(attended)

        # Swap the axes back and flatten every segment into one vector per step.
        restored = tf.transpose(a=encoded, perm=[0, 2, 1])
        restored = tf.reshape(
            restored,
            shape=[-1, self.n_steps, self.freq_size * self.segment_width])

        projected = self.dense(self.dropout(restored))
        return self.layer_norm(projected)
Example #8
0
def gather_blocks_2d(x, indices):
    """Collect the blocks of x addressed by indices.

    Parameters
    ----------
    x
        Tensor whose axes 2 and 3 are merged into one flat axis before
        gathering.
    indices
        Positions along that flat axis to gather.

    Returns
    -------
    blocks
        Per the original comment:
        [batch, heads, num_blocks, block_length ** 2, dim]
    """
    dims = shape_list(x)
    flat_length = tf.reduce_prod(dims[2:4])
    flattened = reshape_range(x, 2, 4, [flat_length])

    # Bring the flat axis to the front: [length, batch, heads, dim]
    length_major = tf.transpose(flattened, [2, 0, 1, 3])
    gathered = tf.gather(length_major, indices)

    # Move the axes of x's leading dimensions back in front of the block axes.
    return tf.transpose(gathered, [2, 3, 0, 1, 4])
Example #9
0
def gather_indices_2d(x, block_shape, block_stride):
    """Compute, for every block, the flat positions of its elements.

    A grid of running indices over axes 2 and 3 of x is convolved with an
    identity kernel, so each output channel picks one position of a
    ``block_shape`` window; windows advance by ``block_stride``.

    Parameters
    ----------
    x
        Tensor whose axes 2 and 3 span the grid to be blocked.
    block_shape
        Height and width of one block.
    block_stride
        Step between consecutive blocks along height and width.

    Returns
    -------
    indices
        An int32 tensor of shape
        [num_blocks, block_shape[0] * block_shape[1]].
    """
    block_size = block_shape[0] * block_shape[1]
    # Identity kernel laid out as [block_h, block_w, 1, block_size].
    identity = reshape_range(tf.eye(block_size), 0, 1,
                             [block_shape[0], block_shape[1], 1])

    # Index grid of shape [1, h, w, 1] to apply the convolution on.
    dims = shape_list(x)
    grid = tf.reshape(tf.range(dims[2] * dims[3]), [1, dims[2], dims[3], 1])
    # The indices travel through float32 for the convolution.
    windows = tf.nn.conv2d(tf.cast(grid, tf.float32),
                           identity,
                           strides=[1, block_stride[0], block_stride[1], 1],
                           padding="VALID")

    # Flatten to [num_blocks, block_size] for gathering.
    leading = shape_list(windows)[:3]
    if all(isinstance(dim, int) for dim in leading):
        num_blocks = functools.reduce(operator.mul, leading, 1)
    else:
        num_blocks = tf.reduce_prod(leading)
    windows = tf.reshape(windows, [num_blocks, -1])
    return tf.cast(windows, tf.int32)
Example #10
0
    def call(self, inp, slope=1):
        """Encode the input and predict where the chord changes.

        Parameters
        ----------
        inp
            Forwarded to ``self.encode_segment_time``.
        slope
            Factor applied to the logits before the sigmoid.

        Returns
        -------
        weighted_hidden_enc
            Sum of the outputs of all blocks, weighted by the softmax of
            ``self.layer_weights``.
        chord_change_logits
            Output of ``self.logit_dense`` without its trailing axis.
        chord_change_pred
            ``binary_round`` of the sigmoid of the scaled logits, cast to int.
        """
        segment_encodings = self.encode_segment_time(inp)
        segment_encodings += positional_encoding(
            batch_size=shape_list(segment_encodings)[0],
            timesteps=self.n_steps,
            n_units=self.enc_input_emb_size)
        segment_encodings = self.dropout(segment_encodings)

        weight = tf.nn.softmax(self.layer_weights)
        weighted_hidden_enc = tf.zeros(shape=shape_list(segment_encodings))
        for idx, (attn_layer, feed_forward) in enumerate(
                zip(self.attn_layers, self.ff_layers)):
            segment_encodings = attn_layer(q=segment_encodings,
                                           k=segment_encodings,
                                           v=segment_encodings)
            segment_encodings = feed_forward(segment_encodings)
            weighted_hidden_enc += weight[idx] * segment_encodings

        # Squeeze the trailing axis only: without `axis`, tf.squeeze would also
        # drop the batch (or step) axis whenever that axis has size 1.
        # Assumes self.logit_dense emits a single unit per step -- confirm.
        chord_change_logits = tf.squeeze(
            self.logit_dense(weighted_hidden_enc), axis=-1)
        chord_change_prob = tf.sigmoid(slope * chord_change_logits)
        chord_change_pred = binary_round(chord_change_prob, cast_to_int=True)

        return weighted_hidden_enc, chord_change_logits, chord_change_pred
Example #11
0
def combine_last_two_dimensions(x):
    """Merge the two trailing dimensions of x into a single one.

    Parameters
    ----------
    x
        A Tensor with shape [..., a, b]

    Returns
    -------
    y
        A Tensor with shape [..., a * b]
    """
    dims = shape_list(x)
    leading = dims[:-2]
    second_last, last = dims[-2:]
    merged_shape = leading + [second_last * last]
    return tf.reshape(x, merged_shape)
Example #12
0
    def call(self, inp):
        """Encode every segment with attention along its frequency axis.

        Parameters
        ----------
        inp
            Input that can be reshaped to
            [-1, self.freq_size, self.segment_width].

        Returns
        -------
        outputs
            ``self.layer_norm`` applied to the dense projection of the
            encoded segments.
        """
        segments = tf.reshape(inp, [-1, self.freq_size, self.segment_width])
        pos_enc = positional_encoding(
            batch_size=shape_list(segments)[0],
            timesteps=self.freq_size,
            n_units=self.segment_width)
        # The positional encoding is scaled down and shifted before being added.
        segments = segments + (pos_enc * 0.01 + 0.01)

        attended = self.attn_layer(q=segments, k=segments, v=segments)
        encoded = self.feed_forward(attended)

        # Flatten every segment into one vector per step.
        flattened = tf.reshape(
            encoded,
            shape=[-1, self.n_steps, self.freq_size * self.segment_width])

        projected = self.out_dense(self.dropout(flattened))
        return self.layer_norm(projected)
Example #13
0
def transpose_residual_block(x, channels, to_down=True, spectral_norm=True, scope='transblock'):
    """Build a residual block of two ELU + 3x3 convolution stages.

    Parameters
    ----------
    x
        Input tensor; its last axis is read as the channel axis.
    channels: int
        Number of output channels of every convolution in the block.
    to_down: bool
        Apply ``down_sample`` to both the residual branch and the shortcut.
    spectral_norm: bool
        Forwarded to every ``conv_sa`` call.
    scope: str
        Name scope, also used as prefix for the scopes of the convolutions.

    Returns
    -------
    y
        Sum of the residual branch and the (possibly projected) shortcut.
    """
    with tf.name_scope(scope):
        in_channels = shape_list(x)[-1]

        # Settings shared by the two convolutions of the residual branch.
        branch_conv = {
            "kernel": (3, 3),
            "strides": (1, 1),
            "pad": 1,
            "pad_type": 'reflect',
            "spectral_norm": spectral_norm,
        }

        with tf.name_scope('res1'):
            branch = tf.keras.layers.ELU()(x)
            branch = conv_sa(branch, channels, scope=f"{scope}_res1", **branch_conv)

        with tf.name_scope('res2'):
            branch = tf.keras.layers.ELU()(branch)
            branch = conv_sa(branch, channels, scope=f"{scope}_res2", **branch_conv)
            if to_down:
                branch = down_sample(branch)

        # The shortcut is projected whenever its shape would not match the branch.
        needs_projection = to_down or in_channels != channels
        if needs_projection:
            with tf.name_scope('shortcut'):
                x = conv_sa(
                    x, channels, kernel=(1, 1), strides=(1, 1), spectral_norm=spectral_norm, scope=f"{scope}_shortcut"
                )
                if to_down:
                    x = down_sample(x)

        return branch + x
Example #14
0
def split_last_dimension(x, n):
    """Split the trailing dimension of x into two dimensions.

    The first of the two new dimensions has size n.

    Parameters
    ----------
    x
        A Tensor with shape [..., m]
    n: int
        Size of the first new dimension; checked to divide m when both are
        plain integers.

    Returns
    -------
    y
        A Tensor with shape [..., n, m/n]
    """
    dims = shape_list(x)
    m = dims[-1]
    # Divisibility can only be checked when neither value is a tensor.
    both_static = isinstance(m, int) and isinstance(n, int)
    if both_static:
        assert m % n == 0
    split_shape = dims[:-1] + [n, m // n]
    return tf.reshape(x, split_shape)
Example #15
0
    def build(self, input_shape):
        """Build the wrapped layer and create the ``sn_u`` / ``sn_v`` weights.

        Parameters
        ----------
        input_shape
            Shape of the layer input; forwarded to the wrapped layer and to
            the parent ``build``.
        """
        self.layer.build(input_shape)

        self.w = self.layer.kernel
        self.w_shape = shape_list(self.w)

        # Non-trainable row vector sized after the product of the first three
        # kernel axes (presumably one side of the spectral-norm power
        # iteration -- confirm where it is used).
        self.v = self.add_weight(
            shape=(1, self.w_shape[0] * self.w_shape[1] * self.w_shape[2]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name='sn_v',
            dtype=tf.float32
        )

        # Non-trainable row vector sized after the last kernel axis.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name='sn_u',
            dtype=tf.float32
        )

        # Forward the shape: the base Keras ``Layer.build`` requires it, and a
        # ``Wrapper`` parent would otherwise receive ``None``.
        super().build(input_shape)
Example #16
0
def local_attention_2d(q,
                       k,
                       v,
                       query_shape=(8, 16),
                       memory_flange=(8, 16),
                       name=None):
    """Strided block local self-attention.

    The 2-D sequence is divided into 2-D blocks of shape query_shape. The
    memory of a query block is the block itself extended by
    ``memory_flange[0]`` positions above and below and by
    ``memory_flange[1]`` positions to the left and to the right.

    Parameters
    ----------
    q
        A tensor with shape [batch, heads, h, w, depth_k]
    k
        A tensor with shape [batch, heads, h, w, depth_k]
    v
        A tensor with shape [batch, heads, h, w, depth_v]. In the current
        implementation, depth_v must be equal to depth_k.
    query_shape: tuple
        A tuple indicating the height and width of each query block.
    memory_flange: tuple
        A tuple indicating how far to look beyond each query block along the
        height and along the width.
    name: str
        An optional string

    Returns
    -------
    y
        A Tensor of shape [batch, heads, h, w, depth_v]
    """
    with tf.compat.v1.variable_scope(name,
                                     default_name="local_self_attention_2d",
                                     values=[q, k, v]):
        v_shape = shape_list(v)

        # Pad query, key, value to ensure multiple of corresponding lengths.
        q = pad_to_multiple_2d(q, query_shape)
        k = pad_to_multiple_2d(k, query_shape)
        v = pad_to_multiple_2d(v, query_shape)

        # Pad the height by the height flange and the width by the width
        # flange on both sides. This matches memory_shape below, so keys and
        # queries yield the same number of blocks even when the two flange
        # values differ.
        flange_h, flange_w = memory_flange[0], memory_flange[1]
        paddings = [[0, 0], [0, 0], [flange_h, flange_h],
                    [flange_w, flange_w], [0, 0]]
        k = tf.pad(k, paddings)
        v = tf.pad(v, paddings)

        # Set up query blocks.
        q_indices = gather_indices_2d(q, query_shape, query_shape)
        q_new = gather_blocks_2d(q, q_indices)

        # Set up key and value blocks: memory windows advance by query_shape.
        memory_shape = (query_shape[0] + 2 * flange_h,
                        query_shape[1] + 2 * flange_w)
        k_and_v_indices = gather_indices_2d(k, memory_shape, query_shape)
        k_new = gather_blocks_2d(k, k_and_v_indices)
        v_new = gather_blocks_2d(v, k_and_v_indices)

        # Positions flagged by embedding_to_padding get a large negative bias.
        attention_bias = tf.expand_dims(
            tf.compat.v1.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
        output = dot_product_attention(q_new,
                                       k_new,
                                       v_new,
                                       attention_bias,
                                       dropout_rate=0.0,
                                       name="local_2d")
        # Put representations back into original shapes.
        padded_q_shape = shape_list(q)
        output = scatter_blocks_2d(output, q_indices, padded_q_shape)

        # Remove the padding if introduced.
        output = tf.slice(output, [0, 0, 0, 0, 0],
                          [-1, -1, v_shape[2], v_shape[3], -1])
        return output
Example #17
0
def reshape_range(tensor, i, j, shape):
    """Replace dimensions i up to (excluding) j of tensor by shape.

    Parameters
    ----------
    tensor
        Tensor to reshape.
    i: int
        First dimension to replace.
    j: int
        First dimension after the replaced range.
    shape: list
        Dimensions put in place of the replaced range; concatenated with the
        kept dimensions, so it has to be a list.

    Returns
    -------
    y
        The reshaped tensor.
    """
    dims = shape_list(tensor)
    kept_before, kept_after = dims[:i], dims[j:]
    return tf.reshape(tensor, kept_before + shape + kept_after)
Example #18
0
def drum_model(out_classes, mini_beat_per_seg, res_block_num=3, channels=64, spectral_norm=True):
    """Get the drum transcription model.

    Constructs the drum transcription model instance for training/inference.

    Parameters
    ----------
    out_classes: int
        Total output classes, referring to classes of drum types.
        Currently there are 13 pre-defined drum percussions.
    mini_beat_per_seg: int
        Number of mini beats in a segment. Can be understood as the range of time
        to be considered for training.
    res_block_num: int
        Number of residual blocks stacked after the attention block. With
        zero, the attention output is fed straight to the dense layers.
    channels: int
        Number of channels used by the residual and attention blocks.
    spectral_norm: bool
        Forwarded to the residual blocks.

    Returns
    -------
    model: tf.keras.Model
        A tensorflow keras model instance. It takes inputs of shape
        [batch, 120, 120, mini_beat_per_seg] and outputs
        [batch, out_classes, mini_beat_per_seg].
    """
    with tf.name_scope('transcription_model'):
        inp_wrap = tf.keras.Input(shape=(120, 120, mini_beat_per_seg), name="input_tensor")
        # Keep the scale factor in the input dtype: TensorFlow does not promote
        # an int32 constant when it is multiplied with a float tensor.
        input_tensor = inp_wrap * tf.constant(100, dtype=inp_wrap.dtype)

        # Difference between every position and its predecessor along axis 2;
        # the first position has no predecessor and is zeroed out.
        padded_input = tf.pad(input_tensor, [[0, 0], [0, 0], [1, 0], [0, 0]], name='tf_diff2_pady')[:, :, :-1]
        pad_out = input_tensor - padded_input
        pad_out = tf.concat([tf.zeros_like(pad_out)[:, :, :1], pad_out[:, :, 1:]], axis=2)
        pad_out = tf.concat([input_tensor, pad_out], axis=-1, name='tf_diff2_concat')

        out = residual_block(pad_out, channels=channels, spectral_norm=spectral_norm, scope='init_resbk')
        out = transpose_residual_block(out, channels=channels, spectral_norm=spectral_norm, scope='fd_resbk')
        x_att, _ = cnn_attention(out, channels=channels, scope='self_attn')

        # Seeding with the attention output keeps out_2 defined for
        # res_block_num == 0.
        out_2 = x_att
        for i in range(res_block_num):
            # Only a last block that is not also the first one skips the
            # down-sampling.
            is_trailing_block = i != 0 and i == res_block_num - 1
            out_2 = transpose_residual_block(
                out_2,
                channels=channels,
                spectral_norm=spectral_norm,
                to_down=not is_trailing_block,
                scope=f"md_resbk_{i}"
            )

        flat_out = tf.reshape(out_2, shape=[-1, tf.math.reduce_prod(shape_list(out_2)[1:])])
        dense_1 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_o1')(flat_out)
        dense_2 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_o2')(dense_1)
        mix_1 = dense_1 + dense_2*0.25  # noqa: E226

        dense_3 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_o3')(mix_1)
        mix_2 = dense_3 + mix_1*0.25  # noqa: E226

        dense_4 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_of2')(mix_2)
        dense_5 = tf.keras.layers.Dense(out_classes * mini_beat_per_seg, activation='tanh',
                                        name='mdl_nn_mlp_of3')(dense_4)

        # tanh lies in [-1, 1], so the output is mapped to [-20, 120].
        out = dense_5*70 + 50  # noqa: E226
        out = tf.reshape(out, shape=[-1, out_classes, mini_beat_per_seg])

        return tf.keras.Model(inputs=inp_wrap, outputs=out)