def cnn_attention(x, channels, scope='attention'):
    """SAGAN-style self-attention over a CNN feature map."""
    with tf.name_scope(scope):
        ori_shape = shape_list(x)
        max_pooling = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same")

        # f and g form the attention map; h carries the values.
        out_f = conv_sa(x, channels // 8, kernel=(1, 1), strides=(1, 1), scope="f_attn_conv")
        out_f = max_pooling(out_f)
        out_g = conv_sa(x, channels // 8, kernel=(1, 1), strides=(1, 1), scope='g_attn_conv')
        out_h = conv_sa(x, channels // 2, kernel=(1, 1), strides=(1, 1), scope='h_attn_conv')
        out_h = max_pooling(out_h)

        shape_f = shape_list(out_f)
        shape_g = shape_list(out_g)
        flatten_f = tf.reshape(out_f, shape=[-1, shape_f[1] * shape_f[2], shape_f[3]])
        flatten_g = tf.reshape(out_g, shape=[-1, shape_g[1] * shape_g[2], shape_g[3]])
        attn_out = tf.linalg.matmul(flatten_g, flatten_f, transpose_b=True)
        attn_matrix = tf.keras.activations.softmax(attn_out)

        shape_h = shape_list(out_h)
        flatten_h = tf.reshape(out_h, shape=[-1, shape_h[1] * shape_h[2], shape_h[3]])
        attn_out_2 = tf.linalg.matmul(attn_matrix, flatten_h)

        out = tf.reshape(attn_out_2, shape=[-1, ori_shape[1], ori_shape[2], ori_shape[3] // 2])
        out = conv_sa(out, channels, kernel=(1, 1), strides=(1, 1), scope='out_attn_conv')

        # Learnable gate, initialized to zero so training starts from the
        # plain convolutional features.
        gamma = tf.compat.v1.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
        x_att = gamma*out + x  # noqa: E226
        x_gma = gamma * out
    return x_att, x_gma

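# Usage sketch (illustrative, not part of the library): wires cnn_attention
# into a minimal Keras functional graph, the same way drum_model below
# consumes it. The input size and channel count are arbitrary demo values;
# channels must match the input's channel axis for the reshape above to
# hold. Because of the tf.compat.v1.get_variable call, this assumes a
# TF1-compat graph context.
def _demo_cnn_attention():
    inp = tf.keras.Input(shape=(16, 16, 64))
    x_att, x_gma = cnn_attention(inp, channels=64, scope='demo_attn')
    # x_att is the gated residual output; x_gma is the attention branch
    # alone. Both keep the input's spatial shape.
    return tf.keras.Model(inputs=inp, outputs=[x_att, x_gma])
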
def call(self, inp, encoder_input_emb, chord_change_pred):
    segment_encodings = self.encode_segment_frequency(inp)

    # Compress frames into chord blocks, then expand back so every frame
    # within a block shares the same encoding.
    segment_encodings_blocked, block_ids = chord_block_compression(
        segment_encodings, chord_change_pred)
    segment_encodings_blocked = chord_block_decompression(
        segment_encodings_blocked, block_ids)
    segment_encodings_blocked.set_shape(
        [None, self.n_steps, self.dec_input_emb_size])

    decoder_inputs = segment_encodings + segment_encodings_blocked + encoder_input_emb
    decoder_inputs += positional_encoding(
        batch_size=shape_list(decoder_inputs)[0],
        timesteps=self.n_steps,
        n_units=self.dec_input_emb_size)
    decoder_inputs_drop = self.dropout(decoder_inputs)

    # Softmax over zeros gives uniform weights across attention blocks.
    layer_weights = tf.nn.softmax(tf.zeros((self.num_attn_blocks)))
    weighted_hiddens_dec = tf.zeros(shape=shape_list(segment_encodings))
    layer_stack = zip(self.attn_layers_1, self.attn_layers_2, self.ff_layers)
    for idx, (attn_1, attn_2, feed_forward) in enumerate(layer_stack):
        decoder_inputs_drop = attn_1(q=decoder_inputs_drop, k=decoder_inputs_drop, v=decoder_inputs_drop)
        decoder_inputs_drop = attn_2(q=decoder_inputs_drop, k=encoder_input_emb, v=encoder_input_emb)
        decoder_inputs_drop = feed_forward(decoder_inputs_drop)
        weighted_hiddens_dec += layer_weights[idx] * decoder_inputs_drop

    logits = self.out_dense(weighted_hiddens_dec)
    chord_pred = tf.argmax(input=logits, axis=-1, output_type=tf.int32)
    return logits, chord_pred

def pad_to_multiple_2d(x, block_shape):
    """Pad x so that its height and width are multiples of block_shape.

    Parameters
    ----------
    x
        A [batch, heads, h, w, depth] or [batch, h, w, depth] tensor
    block_shape
        A 2D list of integer shapes

    Returns
    -------
    padded_x
        A [batch, heads, h, w, depth] or [batch, h, w, depth] tensor
    """
    old_shape = x.get_shape().dims
    last = old_shape[-1]
    if len(old_shape) == 4:
        height_padding = -shape_list(x)[1] % block_shape[0]
        width_padding = -shape_list(x)[2] % block_shape[1]
        paddings = [[0, 0], [0, height_padding], [0, width_padding], [0, 0]]
    elif len(old_shape) == 5:
        height_padding = -shape_list(x)[2] % block_shape[0]
        width_padding = -shape_list(x)[3] % block_shape[1]
        paddings = [[0, 0], [0, 0], [0, height_padding], [0, width_padding], [0, 0]]

    padded_x = tf.pad(x, paddings)
    padded_shape = padded_x.get_shape().as_list()
    padded_shape = padded_shape[:-1] + [last]
    padded_x.set_shape(padded_shape)
    return padded_x

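# Quick shape check (illustrative; assumes eager TF2): the helper rounds
# the spatial dims up to the next multiple of block_shape, leaving batch
# and depth untouched.
def _demo_pad_to_multiple_2d():
    x = tf.zeros([2, 5, 7, 16])             # [batch, h, w, depth]
    padded = pad_to_multiple_2d(x, [4, 4])
    assert padded.shape.as_list() == [2, 8, 8, 16]  # 5 -> 8, 7 -> 8
    return padded
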
def call(self, q, k, v):  # pylint: disable=W0221
    q_emb = self.q_emb_dense(q)
    k_emb = self.k_emb_dense(k)
    v_emb = self.v_emb_dense(v)

    # Split heads onto the batch axis: [batch*n_heads, t, d/n_heads].
    q_heads = tf.concat(tf.split(q_emb, self.n_heads, 2), 0)
    k_heads = tf.concat(tf.split(k_emb, self.n_heads, 2), 0)
    v_heads = tf.concat(tf.split(v_emb, self.n_heads, 2), 0)

    attn_weights = tf.matmul(q_heads, tf.transpose(k_heads, perm=[0, 2, 1]))
    if self.relative_position:
        tk, dk = shape_list(k_heads)[1:]
        rel_pos_enc_k = relative_positional_encoding(
            n_steps=tk, n_units=dk, max_dist=self.max_dist)
        rel_pos_enc_k = tf.matmul(tf.transpose(a=q_heads, perm=[1, 0, 2]),
                                  rel_pos_enc_k,
                                  transpose_b=True)
        rel_pos_enc_k = tf.transpose(a=rel_pos_enc_k, perm=[1, 0, 2])
        attn_weights += rel_pos_enc_k
    scaled_attn_weights = attn_weights / shape_list(k_heads)[-1]**0.5

    if self.causal:
        # Lower-triangular mask: queries cannot attend to later positions.
        diag_vals = tf.ones_like(scaled_attn_weights[0])
        tril_mask = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()
        tril_paddings = tf.ones_like(tril_mask) * (-2**32 + 1)
        tril_masking = lambda x: tf.where(tril_mask == 0, tril_paddings, x)
        scaled_attn_weights = tf.map_fn(tril_masking, scaled_attn_weights)
    if self.self_mask:
        # Set each position's self-attention logit to zero.
        diag = tf.zeros_like(scaled_attn_weights[:, :, 0])
        scaled_attn_weights = tf.linalg.set_diag(input=scaled_attn_weights, diagonal=diag)

    exp_attn_weights = tf.nn.softmax(scaled_attn_weights)
    exp_attn_weights = self.dropout(exp_attn_weights)
    outputs = tf.matmul(exp_attn_weights, v_heads)

    if self.relative_position:
        tv, dv = shape_list(v_heads)[1:]
        rel_pos_enc_v = relative_positional_encoding(
            n_steps=tv, n_units=dv, max_dist=self.max_dist)
        rel_pos_enc_v = tf.matmul(tf.transpose(a=exp_attn_weights, perm=[1, 0, 2]),
                                  rel_pos_enc_v)
        rel_pos_enc_v = tf.transpose(a=rel_pos_enc_v, perm=[1, 0, 2])
        outputs += rel_pos_enc_v

    outputs = tf.concat(tf.split(outputs, self.n_heads, 0), 2)  # Restore shape
    outputs = self.out_dense(outputs)
    outputs += q  # Residual connection
    return self.layer_norm(outputs)

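# Illustration of the head split/merge trick used in call() above: splitting
# the last axis into n_heads chunks and stacking them along the batch axis
# turns [batch, t, d] into [batch * n_heads, t, d // n_heads], so a single
# matmul scores all heads at once. Plain-TF sketch; no layer state involved.
def _demo_head_split_merge(n_heads=4):
    x = tf.zeros([2, 10, 32])                            # [batch, t, d]
    heads = tf.concat(tf.split(x, n_heads, 2), 0)        # [8, 10, 8]
    merged = tf.concat(tf.split(heads, n_heads, 0), 2)   # back to [2, 10, 32]
    return heads.shape, merged.shape
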
def scatter_blocks_2d(x, indices, shape):
    """Scatters blocks from x into shape with indices."""
    x_shape = shape_list(x)
    # [length, batch, heads, dim]
    x_t = tf.transpose(
        tf.reshape(x, [x_shape[0], x_shape[1], -1, x_shape[-1]]), [2, 0, 1, 3])
    x_t_shape = shape_list(x_t)
    indices = tf.reshape(indices, [-1, 1])
    scattered_x = tf.scatter_nd(indices, x_t, x_t_shape)
    scattered_x = tf.transpose(scattered_x, [1, 2, 0, 3])
    return tf.reshape(scattered_x, shape)

def conv_sa(x, channels, kernel=(4, 4), strides=(2, 2), pad=0, pad_type="zero",
            spectral_norm=True, scope="conv_0"):
    with tf.name_scope(scope):
        if pad > 0:
            height = shape_list(x)[1]
            if height % strides[1] == 0:
                pad *= 2
            else:
                pad = max(kernel[1] - (height % strides[1]), 0)

            pad_top = pad // 2
            pad_bottom = pad - pad_top
            pad_left = pad // 2
            pad_right = pad - pad_left

            if pad_type == 'zero':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
            if pad_type == 'reflect':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
                           mode='REFLECT')

        if spectral_norm:
            return ConvSN2D(channels, kernel_size=kernel, strides=strides, scope=f"{scope}_SN")(x)
        return tf.keras.layers.Conv2D(
            channels, kernel_size=kernel, strides=strides, padding='same', name=f"{scope}_conv"
        )(x)

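# Usage sketch (illustrative): a plain strided convolution through conv_sa
# with spectral normalization turned off, so only standard Keras APIs are
# exercised. All sizes are arbitrary demo values.
def _demo_conv_sa():
    inp = tf.keras.Input(shape=(32, 32, 3))
    out = conv_sa(inp, channels=16, kernel=(3, 3), strides=(2, 2),
                  spectral_norm=False, scope='demo_conv')
    return tf.keras.Model(inp, out)   # output: [batch, 16, 16, 16]
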
def call(self, inp):
    # output dim: [batch_size*n_steps, tonal_size, segment_width]
    inp_reshape = tf.reshape(
        inp, shape=[-1, self.freq_size, self.segment_width])

    # output dim: [batch_size*n_steps, segment_width, tonal_size]
    inp_permute = tf.transpose(a=inp_reshape, perm=[0, 2, 1])
    inp_permute += positional_encoding(
        batch_size=shape_list(inp_permute)[0],
        timesteps=self.segment_width,
        n_units=self.freq_size) * 0.01 + 0.01

    attn_output = self.attn_layer(q=inp_permute, k=inp_permute, v=inp_permute)
    forward_output = self.feed_forward(attn_output)

    # restore shape
    outputs = tf.transpose(a=forward_output, perm=[0, 2, 1])
    outputs = tf.reshape(
        outputs, shape=[-1, self.n_steps, self.freq_size * self.segment_width])
    outputs = self.dropout(outputs)
    outputs = self.dense(outputs)
    return self.layer_norm(outputs)

def gather_blocks_2d(x, indices):
    """Gathers flattened blocks from x."""
    x_shape = shape_list(x)
    x = reshape_range(x, 2, 4, [tf.reduce_prod(x_shape[2:4])])
    # [length, batch, heads, dim]
    x_t = tf.transpose(x, [2, 0, 1, 3])
    x_new = tf.gather(x_t, indices)
    # returns [batch, heads, num_blocks, block_length ** 2, dim]
    return tf.transpose(x_new, [2, 3, 0, 1, 4])

def gather_indices_2d(x, block_shape, block_stride):
    """Getting gather indices."""
    # making an identity matrix kernel
    kernel = tf.eye(block_shape[0] * block_shape[1])
    kernel = reshape_range(kernel, 0, 1, [block_shape[0], block_shape[1], 1])

    # making indices [1, h, w, 1] to apply convs
    x_shape = shape_list(x)
    indices = tf.range(x_shape[2] * x_shape[3])
    indices = tf.reshape(indices, [1, x_shape[2], x_shape[3], 1])
    indices = tf.nn.conv2d(tf.cast(indices, tf.float32),
                           kernel,
                           strides=[1, block_stride[0], block_stride[1], 1],
                           padding="VALID")

    # making indices [num_blocks, dim] to gather
    dims = shape_list(indices)[:3]
    if all([isinstance(dim, int) for dim in dims]):
        num_blocks = functools.reduce(operator.mul, dims, 1)
    else:
        num_blocks = tf.reduce_prod(dims)
    indices = tf.reshape(indices, [num_blocks, -1])
    return tf.cast(indices, tf.int32)

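# Worked example (illustrative; assumes eager TF2): on a 4x4 map split into
# 2x2 blocks with stride 2, gather_indices_2d yields four rows of flattened
# positions, and gather_blocks_2d above uses them to pull out the block
# contents.
def _demo_gather_blocks():
    x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 1, 4, 4, 1])
    idx = gather_indices_2d(x, block_shape=(2, 2), block_stride=(2, 2))
    # idx[0] == [0, 1, 4, 5]: the flat positions of the top-left block.
    blocks = gather_blocks_2d(x, idx)
    # blocks: [1, 1, 4, 4, 1] = [batch, heads, num_blocks, block_len**2, dim]
    return idx, blocks
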
def call(self, inp, slope=1):
    segment_encodings = self.encode_segment_time(inp)
    segment_encodings += positional_encoding(
        batch_size=shape_list(segment_encodings)[0],
        timesteps=self.n_steps,
        n_units=self.enc_input_emb_size)
    segment_encodings = self.dropout(segment_encodings)

    weight = tf.nn.softmax(self.layer_weights)
    weighted_hidden_enc = tf.zeros(shape=shape_list(segment_encodings))
    for idx, (attn_layer, feed_forward) in enumerate(
            zip(self.attn_layers, self.ff_layers)):
        segment_encodings = attn_layer(q=segment_encodings, k=segment_encodings, v=segment_encodings)
        segment_encodings = feed_forward(segment_encodings)
        weighted_hidden_enc += weight[idx] * segment_encodings

    chord_change_logits = tf.squeeze(self.logit_dense(weighted_hidden_enc))
    # `slope` sharpens the sigmoid toward a hard decision; binary_round then
    # discretizes the probability.
    chord_change_prob = tf.sigmoid(slope * chord_change_logits)
    chord_change_pred = binary_round(chord_change_prob, cast_to_int=True)
    return weighted_hidden_enc, chord_change_logits, chord_change_pred

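# Slope-annealing sketch (illustrative numbers only): as `slope` grows,
# sigmoid(slope * logits) saturates toward hard 0/1 chord-change decisions,
# which binary_round then snaps to integers.
def _demo_slope_annealing():
    logits = tf.constant([-1.5, 0.2, 3.0])
    return [tf.sigmoid(slope * logits) for slope in (1.0, 5.0, 25.0)]
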
def combine_last_two_dimensions(x):
    """Reshape x so that the last two dimensions become one.

    Parameters
    ----------
    x
        A Tensor with shape [..., a, b]

    Returns
    -------
    y
        A Tensor with shape [..., ab]
    """
    x_shape = shape_list(x)
    a, b = x_shape[-2:]
    return tf.reshape(x, x_shape[:-2] + [a * b])  # noqa: E226

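# Quick shape check (illustrative; assumes eager TF2) for the helper above.
def _demo_combine_last_two_dimensions():
    x = tf.zeros([2, 3, 4, 5])
    y = combine_last_two_dimensions(x)
    assert y.shape.as_list() == [2, 3, 20]
    return y
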
def call(self, inp):
    inp_reshape = tf.reshape(inp, [-1, self.freq_size, self.segment_width])
    inp_reshape += positional_encoding(
        batch_size=shape_list(inp_reshape)[0],
        timesteps=self.freq_size,
        n_units=self.segment_width) * 0.01 + 0.01

    attn_output = self.attn_layer(q=inp_reshape, k=inp_reshape, v=inp_reshape)
    forward_output = self.feed_forward(attn_output)

    # restore shape
    outputs = tf.reshape(
        forward_output, shape=[-1, self.n_steps, self.freq_size * self.segment_width])
    outputs = self.dropout(outputs)
    outputs = self.out_dense(outputs)
    return self.layer_norm(outputs)

def transpose_residual_block(x, channels, to_down=True, spectral_norm=True, scope='transblock'):
    with tf.name_scope(scope):
        init_channel = shape_list(x)[-1]
        with tf.name_scope('res1'):
            out = tf.keras.layers.ELU()(x)
            out = conv_sa(out, channels, kernel=(3, 3), strides=(1, 1), pad=1, pad_type='reflect',
                          spectral_norm=spectral_norm, scope=f"{scope}_res1")
        with tf.name_scope('res2'):
            out = tf.keras.layers.ELU()(out)
            out = conv_sa(out, channels, kernel=(3, 3), strides=(1, 1), pad=1, pad_type='reflect',
                          spectral_norm=spectral_norm, scope=f"{scope}_res2")
        if to_down:
            out = down_sample(out)

        # Project the shortcut when the block downsamples or changes width.
        if to_down or init_channel != channels:
            with tf.name_scope('shortcut'):
                x = conv_sa(x, channels, kernel=(1, 1), strides=(1, 1),
                            spectral_norm=spectral_norm, scope=f"{scope}_shortcut")
                if to_down:
                    x = down_sample(x)
    return out + x

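# Usage sketch (illustrative): one downsampling residual block inside a
# Keras functional graph. Relies on this module's ConvSN2D and down_sample
# helpers; the channel and spatial sizes are arbitrary demo values.
def _demo_transpose_residual_block():
    inp = tf.keras.Input(shape=(32, 32, 16))
    out = transpose_residual_block(inp, channels=32, scope='demo_tb')
    return tf.keras.Model(inp, out)   # spatial dims halved by down_sample
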
def split_last_dimension(x, n):
    """Reshape x so that the last dimension becomes two dimensions.

    The first of these two dimensions is n.

    Parameters
    ----------
    x
        A Tensor with shape [..., m]
    n: int
        The size of the first new dimension; must evenly divide m.

    Returns
    -------
    y
        A Tensor with shape [..., n, m/n]
    """
    x_shape = shape_list(x)
    m = x_shape[-1]
    if isinstance(m, int) and isinstance(n, int):
        assert m % n == 0
    return tf.reshape(x, x_shape[:-1] + [n, m // n])

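# Quick shape check (illustrative; assumes eager TF2): the inverse of
# combine_last_two_dimensions, up to the choice of n.
def _demo_split_last_dimension():
    x = tf.zeros([2, 3, 20])
    y = split_last_dimension(x, 4)
    assert y.shape.as_list() == [2, 3, 4, 5]
    return y
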
def build(self, input_shape):
    self.layer.build(input_shape)
    self.w = self.layer.kernel
    self.w_shape = shape_list(self.w)

    # Persistent singular-vector estimates used by spectral normalization's
    # power iteration; updated outside the gradient path.
    self.v = self.add_weight(
        shape=(1, self.w_shape[0] * self.w_shape[1] * self.w_shape[2]),
        initializer=tf.initializers.TruncatedNormal(stddev=0.02),
        trainable=False,
        name='sn_v',
        dtype=tf.float32)
    self.u = self.add_weight(
        shape=(1, self.w_shape[-1]),
        initializer=tf.initializers.TruncatedNormal(stddev=0.02),
        trainable=False,
        name='sn_u',
        dtype=tf.float32)
    super().build()

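# One power-iteration step for estimating the spectral norm sigma(W) that
# the u/v weights above exist for. This is a standalone sketch of the
# standard update, not the wrapper's actual call() (which is not shown
# here); w and u follow the shapes created in build().
def _demo_power_iteration(w, u):
    w_mat = tf.reshape(w, [-1, tf.shape(w)[-1]])                     # [k*k*c_in, c_out]
    v = tf.math.l2_normalize(tf.matmul(u, w_mat, transpose_b=True))  # [1, k*k*c_in]
    u_new = tf.math.l2_normalize(tf.matmul(v, w_mat))                # [1, c_out]
    sigma = tf.matmul(tf.matmul(v, w_mat), u_new, transpose_b=True)  # [1, 1]
    return sigma, u_new
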
def local_attention_2d(q, k, v, query_shape=(8, 16), memory_flange=(8, 16), name=None):
    """Strided block local self-attention.

    The 2-D sequence is divided into 2-D blocks of shape query_shape.
    Attention for a given query position can only see memory positions less
    than or equal to the query position. The memory positions are the
    corresponding block, with memory_flange many positions added to the
    height and width of the block (namely, left, top, and right).

    Parameters
    ----------
    q
        A tensor with shape [batch, heads, h, w, depth_k]
    k
        A tensor with shape [batch, heads, h, w, depth_k]
    v
        A tensor with shape [batch, heads, h, w, depth_v]. In the current
        implementation, depth_v must be equal to depth_k.
    query_shape: tuple
        A tuple indicating the height and width of each query block.
    memory_flange: tuple
        A tuple indicating how much to look in height and width from each
        query block.
    name: str
        An optional string

    Returns
    -------
    y
        A Tensor of shape [batch, heads, h, w, depth_v]
    """
    with tf.compat.v1.variable_scope(name, default_name="local_self_attention_2d", values=[q, k, v]):
        v_shape = shape_list(v)

        # Pad query, key, value to ensure multiple of corresponding lengths.
        q = pad_to_multiple_2d(q, query_shape)
        k = pad_to_multiple_2d(k, query_shape)
        v = pad_to_multiple_2d(v, query_shape)
        paddings = [[0, 0], [0, 0], [memory_flange[0], memory_flange[1]],
                    [memory_flange[0], memory_flange[1]], [0, 0]]
        k = tf.pad(k, paddings)
        v = tf.pad(v, paddings)

        # Set up query blocks.
        q_indices = gather_indices_2d(q, query_shape, query_shape)
        q_new = gather_blocks_2d(q, q_indices)

        # Set up key and value blocks.
        memory_shape = (query_shape[0] + 2 * memory_flange[0],
                        query_shape[1] + 2 * memory_flange[1])
        k_and_v_indices = gather_indices_2d(k, memory_shape, query_shape)
        k_new = gather_blocks_2d(k, k_and_v_indices)
        v_new = gather_blocks_2d(v, k_and_v_indices)

        attention_bias = tf.expand_dims(
            tf.compat.v1.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
        output = dot_product_attention(q_new, k_new, v_new, attention_bias,
                                       dropout_rate=0.0, name="local_2d")

        # Put representations back into original shapes.
        padded_q_shape = shape_list(q)
        output = scatter_blocks_2d(output, q_indices, padded_q_shape)

        # Remove the padding if introduced.
        output = tf.slice(output, [0, 0, 0, 0, 0],
                          [-1, -1, v_shape[2], v_shape[3], -1])
        return output

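# Usage sketch (illustrative): local 2-D self-attention over a dummy 5-D
# tensor. Depends on the embedding_to_padding and dot_product_attention
# helpers referenced above being available in this module; sizes are
# arbitrary but chosen so h and w divide evenly into query blocks.
def _demo_local_attention_2d():
    q = tf.random.normal([1, 2, 16, 32, 8])   # [batch, heads, h, w, depth]
    out = local_attention_2d(q, q, q, query_shape=(8, 16), memory_flange=(8, 16))
    # Output keeps the input shape: [1, 2, 16, 32, 8].
    return out
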
def reshape_range(tensor, i, j, shape):
    """Reshapes a tensor between dimensions i and j."""
    t_shape = shape_list(tensor)
    target_shape = t_shape[:i] + shape + t_shape[j:]
    return tf.reshape(tensor, target_shape)

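# Quick shape check (illustrative; assumes eager TF2): collapse dims 2..3
# into a single dimension of size 20.
def _demo_reshape_range():
    x = tf.zeros([2, 3, 4, 5, 6])
    y = reshape_range(x, 2, 4, [20])
    assert y.shape.as_list() == [2, 3, 20, 6]
    return y
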
def drum_model(out_classes, mini_beat_per_seg, res_block_num=3, channels=64, spectral_norm=True):
    """Get the drum transcription model.

    Constructs the drum transcription model instance for training/inference.

    Parameters
    ----------
    out_classes: int
        Total output classes, referring to classes of drum types. Currently
        there are 13 pre-defined drum percussions.
    mini_beat_per_seg: int
        Number of mini beats in a segment. Can be understood as the range of
        time to be considered for training.
    res_block_num: int
        Number of residual blocks.
    channels: int
        Number of channels used by the convolutional blocks.
    spectral_norm: bool
        Whether to apply spectral normalization to the convolution layers.

    Returns
    -------
    model: tf.keras.Model
        A tensorflow keras model instance.
    """
    with tf.name_scope('transcription_model'):
        inp_wrap = tf.keras.Input(shape=(120, 120, mini_beat_per_seg), name="input_tensor")
        input_tensor = inp_wrap * tf.constant(100)

        # First-order difference along the second spatial axis (first column
        # zeroed), stacked onto the scaled input as extra channels.
        padded_input = tf.pad(input_tensor, [[0, 0], [0, 0], [1, 0], [0, 0]],
                              name='tf_diff2_pady')[:, :, :-1]
        pad_out = input_tensor - padded_input
        pad_out = tf.concat([tf.zeros_like(pad_out)[:, :, :1], pad_out[:, :, 1:]], axis=2)
        pad_out = tf.concat([input_tensor, pad_out], axis=-1, name='tf_diff2_concat')

        out = residual_block(pad_out, channels=channels, spectral_norm=spectral_norm, scope='init_resbk')
        out = transpose_residual_block(out, channels=channels, spectral_norm=spectral_norm, scope='fd_resbk')
        x_att, _ = cnn_attention(out, channels=channels, scope='self_attn')

        for i in range(res_block_num):
            if i == 0:
                # first res layer
                out_2 = transpose_residual_block(
                    x_att, channels=channels, spectral_norm=spectral_norm, scope=f"md_resbk_{i}")
            elif i != res_block_num - 1:
                # middle res layer
                out_2 = transpose_residual_block(
                    out_2, channels=channels, spectral_norm=spectral_norm, scope=f"md_resbk_{i}")
            else:
                # last res layer
                out_2 = transpose_residual_block(
                    out_2, channels=channels, spectral_norm=spectral_norm, to_down=False,
                    scope=f"md_resbk_{i}")

        flat_out = tf.reshape(out_2, shape=[-1, tf.math.reduce_prod(shape_list(out_2)[1:])])
        dense_1 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_o1')(flat_out)
        dense_2 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_o2')(dense_1)
        mix_1 = dense_1 + dense_2*0.25  # noqa: E226
        dense_3 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_o3')(mix_1)
        mix_2 = dense_3 + mix_1*0.25  # noqa: E226
        dense_4 = tf.keras.layers.Dense(2**10, activation='elu', name='mdl_nn_mlp_of2')(mix_2)
        dense_5 = tf.keras.layers.Dense(out_classes * mini_beat_per_seg, activation='tanh',
                                        name='mdl_nn_mlp_of3')(dense_4)
        out = dense_5*70 + 50  # noqa: E226
        out = tf.reshape(out, shape=[-1, out_classes, mini_beat_per_seg])

    return tf.keras.Model(inputs=inp_wrap, outputs=out)

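# Usage sketch (illustrative): build the drum model and push a dummy batch
# through it. The 13 classes and 4 mini-beats are example values taken from
# the docstring; because cnn_attention calls tf.compat.v1.get_variable,
# construction may require a TF1-compat graph context.
def _demo_drum_model():
    model = drum_model(out_classes=13, mini_beat_per_seg=4)
    dummy = tf.zeros([2, 120, 120, 4])
    return model(dummy)   # -> [2, 13, 4]
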