Example #1
def model_fn(features, labels, mode, params):
    # `partition_size`, `init_lr` and `epochs` are expected to be defined at
    # module level; the `params` argument is unused. Only the TRAIN and EVAL
    # modes are handled.

    x = tf.expand_dims(features['combined'], -1)
    y = tf.expand_dims(features['y'], -1)
    partitioned_x = tf_featurization.pad_and_partition(x, partition_size)
    partitioned_y = tf_featurization.pad_and_partition(y, partition_size)
    model = unet.Model(partitioned_x, channels_interval=36)

    # SNR and SDR are only logged as summaries; training minimises the MAE.
    l2_loss, snr = enhancement.loss.snr(model.logits, partitioned_y)
    sdr = enhancement.loss.sdr(model.logits, partitioned_y)
    mae = tf.losses.absolute_difference
    mae_loss = mae(labels=partitioned_y, predictions=model.logits)
    loss = mae_loss

    tf.identity(loss, 'total_loss')
    tf.summary.scalar('total_loss', loss)

    tf.summary.scalar('snr', snr)
    tf.summary.scalar('sdr', sdr)

    global_step = tf.train.get_or_create_global_step()
    # Linear decay from `init_lr` down to 1e-6 over `epochs` steps.
    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        epochs,
        end_learning_rate=1e-6,
        power=1.0,
        cycle=False,
    )
    tf.summary.scalar('learning_rate', learning_rate)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op
        )

    elif mode == tf.estimator.ModeKeys.EVAL:
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL, loss=loss
        )

    return estimator_spec
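
For context, a minimal wiring sketch, not part of the original example: it hands a model_fn like this to a TF1 Estimator. Here `get_dataset` is a hypothetical input_fn yielding ({'combined': ..., 'y': ...}, labels) batches, and the model directory and step counts are arbitrary.

import tensorflow as tf

run_config = tf.estimator.RunConfig(save_checkpoints_steps=5000)
estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='output-unet', config=run_config
)
train_spec = tf.estimator.TrainSpec(input_fn=get_dataset, max_steps=500000)
eval_spec = tf.estimator.EvalSpec(input_fn=get_dataset, steps=100)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
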
Example #2
    def __init__(
        self,
        inputs,
        sources=4,
        audio_channels=1,
        channels=64,
        depth=6,
        rewrite=True,
        use_glu=True,
        rescale=0.1,
        kernel_size=8,
        stride=4,
        growth=2.0,
        lstm_layers=2,
        context=3,
        partition_length=44100 * 2,
        norm_after_partition=False,
        output_shape_same_as_input=False,
        logging=False,
        kernel_initializer=ConvScaling,
    ):
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.partition_length = partition_length

        if use_glu:
            # GLU halves the channel dimension, so convolutions feeding it
            # emit ch_scale (= 2) times the target number of channels.
            activation = glu
            ch_scale = 2
        else:
            activation = tf.nn.relu
            ch_scale = 1

        in_channels = audio_channels

        self.encoder, self.decoder = [], []
        for index in range(depth):
            encoder = tf.keras.Sequential()
            encoder.add(
                tf.keras.layers.Conv1D(
                    channels,
                    kernel_size,
                    stride,
                    activation=tf.nn.relu,
                    kernel_initializer=kernel_initializer,
                ))
            if rewrite:
                encoder.add(
                    tf.keras.layers.Conv1D(
                        ch_scale * channels,
                        1,
                        activation=activation,
                        kernel_initializer=kernel_initializer,
                    ))
            self.encoder.append(encoder)

            # The decoder built at index 0 ends up last in `self.decoder`
            # (see the insert(0, ...) below) and maps back to
            # `sources * audio_channels` output channels.
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = sources * audio_channels

            decoder = tf.keras.Sequential()
            if rewrite:
                decoder.add(
                    tf.keras.layers.Conv1D(
                        ch_scale * channels,
                        context,
                        activation=activation,
                        kernel_initializer=kernel_initializer,
                    ))

            if index > 0:
                a = tf.nn.relu
            else:
                a = None

            decoder.add(
                Conv1DTranspose(
                    out_channels,
                    kernel_size,
                    stride,
                    activation=a,
                    kernel_initializer=kernel_initializer,
                ))
            self.decoder.insert(0, decoder)
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels

        # `partition_length` is assumed to be non-zero: `partitioned` is only
        # defined inside this branch but is used unconditionally below.
        if partition_length:
            partitioned = pad_and_partition(inputs, self.partition_length)
            if norm_after_partition:
                mean = tf.reduce_mean(partitioned, axis=0)
                std = tf.math.reduce_std(partitioned, axis=0)
                partitioned = (partitioned - mean) / std

        # Pad each partition up to the nearest length the encoder/decoder
        # stack can reconstruct without truncation.
        valid_length = self.valid_length(partitioned.shape.as_list()[1])
        delta = valid_length - self.partition_length
        padded = tf.pad(
            partitioned,
            [[0, 0], [delta // 2, delta - delta // 2], [0, 0]],
            'CONSTANT',
        )

        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        x = padded
        saved = [x]
        for encode in self.encoder:
            if logging:
                print(x)
            x = encode(x)
            saved.append(x)

        if logging:
            print('x', x)
        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            if logging:
                print(x)
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)

        if logging:
            print('x', x)

        self.logits = x
        self.logits = tf.reshape(self.logits, (-1, self.sources))
        if output_shape_same_as_input:
            self.logits = self.logits[:tf.shape(inputs)[0]]
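
A minimal construction sketch for the class above, under stated assumptions: TensorFlow 1.x graph mode, a mono waveform placeholder of shape (time, 1), and the class importable as `Model`; none of this is taken from the source.

import tensorflow as tf

waveform = tf.placeholder(tf.float32, (None, 1), name='waveform')
model = Model(waveform, sources=4, audio_channels=1)
# After the final reshape, model.logits is approximately (num_samples, sources).

sess = tf.Session()
sess.run(tf.global_variables_initializer())
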
Example #3
    def __init__(self, X, Y, frame_length=4096, frame_step=1024):
        def get_stft(X):
            return tf.signal.stft(
                X,
                frame_length,
                frame_step,
                window_fn=lambda frame_length, dtype:
                (hann_window(frame_length, periodic=True, dtype=dtype)),
                pad_end=True,
            )

        stft_X = get_stft(X)
        stft_Y = get_stft(Y)
        mag_X = tf.abs(stft_X)
        mag_Y = tf.abs(stft_Y)

        # Note: the imaginary component (not the phase from tf.math.angle)
        # is used as the "angle" feature throughout.
        angle_X = tf.math.imag(stft_X)
        angle_Y = tf.math.imag(stft_Y)

        partitioned_mag_X = tf_featurization.pad_and_partition(mag_X, 512)
        partitioned_angle_X = tf_featurization.pad_and_partition(angle_X, 512)
        params = {'conv_n_filters': [32 * (2**i) for i in range(6)]}

        with tf.variable_scope('model_mag'):
            mix_mag = tf.expand_dims(partitioned_mag_X, 3)[:, :, :-1, :]
            mix_mag_logits = unet.Model(
                mix_mag,
                output_mask_logit=True,
                dropout=0.0,
                training=True,
                params=params,
            ).logits
            mix_mag_logits = tf.squeeze(mix_mag_logits, 3)
            mix_mag_logits = tf.pad(mix_mag_logits, [(0, 0), (0, 0), (0, 1)],
                                    mode='CONSTANT')
            mix_mag_logits = tf.nn.relu(mix_mag_logits)

        with tf.variable_scope('model_angle'):
            mix_angle = tf.expand_dims(partitioned_angle_X, 3)[:, :, :-1, :]
            mix_angle_logits = unet.Model(
                mix_angle,
                output_mask_logit=True,
                dropout=0.0,
                training=True,
                params=params,
            ).logits
            mix_angle_logits = tf.squeeze(mix_angle_logits, 3)
            mix_angle_logits = tf.pad(mix_angle_logits, [(0, 0), (0, 0),
                                                         (0, 1)],
                                      mode='CONSTANT')

        partitioned_mag_Y = tf_featurization.pad_and_partition(mag_Y, 512)
        partitioned_angle_Y = tf_featurization.pad_and_partition(angle_Y, 512)

        self.mag_l1 = tf.reduce_mean(tf.abs(partitioned_mag_Y -
                                            mix_mag_logits))
        self.angle_l1 = tf.reduce_mean(
            tf.abs(partitioned_angle_Y - mix_angle_logits))
        self.cost = self.mag_l1 + self.angle_l1

        def get_original_shape(D, stft):
            instrument_mask = D

            old_shape = tf.shape(instrument_mask)
            new_shape = tf.concat(
                [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
            instrument_mask = tf.reshape(instrument_mask, new_shape)
            instrument_mask = instrument_mask[:tf.shape(stft)[0]]
            return instrument_mask

        _mag = get_original_shape(tf.expand_dims(mix_mag_logits, -1), stft_X)
        _angle = get_original_shape(tf.expand_dims(mix_angle_logits, -1),
                                    stft_X)

        stft = tf.multiply(tf.complex(_mag, 0.0),
                           tf.exp(tf.complex(0.0, _angle)))

        inverse_stft_X = inverse_stft(
            stft[:, :, 0],
            frame_length,
            frame_step,
            window_fn=lambda frame_length, dtype:
            (hann_window(frame_length, periodic=True, dtype=dtype)),
        )
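
A minimal sketch of how this graph might be driven, again under assumptions not in the source: TensorFlow 1.x, the class above importable as `Model`, X and Y as 1-D noisy/clean waveform placeholders, and an arbitrary Adam learning rate.

import tensorflow as tf

X = tf.placeholder(tf.float32, (None,), name='noisy')
Y = tf.placeholder(tf.float32, (None,), name='clean')
model = Model(X, Y)
train_op = tf.train.AdamOptimizer(1e-4).minimize(model.cost)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
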
Example #4
    def __init__(
        self,
        inputs,
        y=None,
        chin=1,
        chout=1,
        hidden=48,
        depth=5,
        use_glu=True,
        kernel_size=8,
        stride=4,
        causal=True,
        resample=4,
        growth=2,
        max_hidden=10000,
        normalize=True,
        rescale=0.1,
        floor=1e-3,
        lstm_layers=2,
        partition_length=44100 * 2,
        norm_after_partition=False,
        logging=False,
        kernel_initializer=ConvScaling,
    ):
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.partition_length = partition_length

        if use_glu:
            activation = glu
            ch_scale = 2
        else:
            activation = tf.nn.relu
            ch_scale = 1

        self.encoder, self.decoder = [], []
        for index in range(depth):
            encoder = tf.keras.Sequential()
            encoder.add(
                tf.keras.layers.Conv1D(
                    hidden,
                    kernel_size,
                    stride,
                    activation=tf.nn.relu,
                    kernel_initializer=kernel_initializer,
                ))
            encoder.add(
                tf.keras.layers.Conv1D(
                    ch_scale * hidden,
                    1,
                    activation=activation,
                    kernel_initializer=kernel_initializer,
                ))
            self.encoder.append(encoder)

            decoder = tf.keras.Sequential()
            decoder.add(
                tf.keras.layers.Conv1D(
                    ch_scale * hidden,
                    1,
                    activation=activation,
                    kernel_initializer=kernel_initializer,
                ))
            if index > 0:
                a = tf.nn.relu
            else:
                a = None
            decoder.add(
                Conv1DTranspose(
                    chout,
                    kernel_size,
                    stride,
                    activation=a,
                    kernel_initializer=kernel_initializer,
                ))
            self.decoder.insert(0, decoder)
            chout = hidden
            chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        self.lstm = BLSTM(chin, bi=not causal)

        if self.normalize:
            mono = tf.reduce_mean(inputs, axis=1, keepdims=True)
            self.std = tf.math.reduce_std(mono, axis=0, keepdims=True)
            inputs = inputs / (self.floor + self.std)
        else:
            self.std = 1.0

        partitioned = pad_and_partition(inputs, self.partition_length)
        if norm_after_partition:
            mean = tf.reduce_mean(partitioned, axis=0)
            std = tf.math.reduce_std(partitioned, axis=0)
            partitioned = (partitioned - mean) / std

        valid_length = self.valid_length(self.partition_length)
        delta = valid_length - self.partition_length
        padded = tf.pad(partitioned, [[0, 0], [0, delta], [0, 0]], 'CONSTANT')
        x = padded
        if logging:
            print(x)
        # resample=4 is realised as two successive 2x upsamplings, mirrored
        # by two 2x downsamplings after decoding.
        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)
        if logging:
            print(x)
        skips = []
        for encode in self.encoder:
            if logging:
                print(x)
            x = encode(x)
            skips.append(x)
        if logging:
            print('x', x)
        x = self.lstm(x)
        for decode in self.decoder:
            if logging:
                print(x)
            skip = skips.pop(-1)
            x = x + skip[:, :tf.shape(x)[1]]
            x = decode(x)
        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        if logging:
            print('x', x)

        self.logits = x
        self.logits = tf.reshape(self.logits, (-1, self.chout))
        if y is not None:
            self.logits = self.logits[:tf.shape(y)[0]]
        else:
            self.logits = self.logits[:tf.shape(inputs)[0]]
        self.logits = self.std * self.logits
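
As before, a minimal construction sketch under assumptions not in the source (TensorFlow 1.x, a mono (time, 1) placeholder, the class above importable as `Model`):

import tensorflow as tf

noisy = tf.placeholder(tf.float32, (None, 1), name='noisy')
model = Model(noisy)
# model.logits is the enhanced waveform, trimmed back to the input length
# and rescaled by the estimated input standard deviation.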