Ejemplo n.º 1
0
    def __model(self, tf_mix, tf_target, tf_lr):
        """Assemble the training graph for the ``MRUNet_ver2`` mask estimator.

        The mixture is analysed with two STFT configurations
        (``self.stft_params`` and ``self.mr_stft_params``). Both log-amplitude
        spectrograms are fed to ``MRUNet_ver2``; the network output is
        extended by one zero bin along axis 2 and regressed (scaled MSE)
        against the mask returned by ``Masks.iaf``.

        Args:
            tf_mix: Mixture tensor passed to both STFT modules.
            tf_target: Target tensor, transformed with the main STFT module.
            tf_lr: Learning rate forwarded to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, ora_masks)``.
        """

        def build_stft(params):
            """Create an ``STFT_Module`` from one parameter mapping."""
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        main_stft = build_stft(self.stft_params)
        mr_stft = build_stft(self.mr_stft_params)

        # Main-resolution mixture: STFT -> amplitude -> log-amplitude.
        mix_spec = main_stft.STFT(tf_mix)
        mix_amp = main_stft.to_amp_spec(mix_spec, normalize=False)
        mix_log = tf.log(mix_amp + self.epsilon)
        # Trailing axis added; the original author labels the layout
        # (Batch, Time, Freq, Channel).
        mix_log = tf.expand_dims(mix_log, -1)
        mix_amp = tf.expand_dims(mix_amp, -1)
        # presumably trims the frequency axis to 512 bins (see the
        # 512 -> 513 padding below) -- confirm against STFT_Module.to_F_512
        mix_log_512 = main_stft.to_F_512(mix_log)

        # Second-resolution mixture: keep indices 1..512 of axis 1.
        mr_spec = mr_stft.STFT(tf_mix)[:, 1:513, :]
        # NOTE(review): amplitude conversion goes through the main STFT
        # module rather than the mr one -- confirm to_amp_spec does not
        # depend on the STFT parameters.
        mr_amp = main_stft.to_amp_spec(mr_spec, normalize=False)
        mr_log = tf.expand_dims(tf.log(mr_amp + self.epsilon), -1)
        mr_log_256 = mr_log[:, :, :256]

        # Target amplitude spectrogram with the same trailing axis.
        target_spec = main_stft.STFT(tf_target)
        target_amp = tf.expand_dims(
            main_stft.to_amp_spec(target_spec, normalize=False), -1)

        network = MRUNet_ver2(
            input_shape=(mix_log_512.shape[1:]),
            mr_input_shape=(mr_log_256.shape[1:]))
        est_masks = network(mix_log_512, mr_log_256)

        # Append one all-zero bin on axis 2 (F: 512 -> 513).
        zero_bin = tf.expand_dims(tf.zeros_like(mix_log)[:, :, 1, :], -1)
        est_masks = tf.concat([est_masks, zero_bin], 2)

        ora_masks = Masks.iaf(mix_amp, target_amp, self.epsilon)
        loss = 10 * Loss.mean_square_error(est_masks, ora_masks)
        train_step = Trainer.Adam(loss, tf_lr)

        return (train_step, loss, target_amp, mix_log, mix_spec, est_masks,
                ora_masks)
Ejemplo n.º 2
0
    def __model(self, tf_mix, tf_target, tf_lr):
        """Assemble the training graph for the ``mini_UNet_ver4`` mask estimator.

        The mixture is transformed with the main STFT (``self.stft_params``)
        and, after zero-padding both ends along axis 1 by up to 384 samples,
        with a second STFT (``self.mr2_stft_params``). Both log-amplitude
        spectrograms feed ``mini_UNet_ver4``; the first of its six outputs is
        used as the mask. The mask gets one zero bin appended on axis 2, is
        multiplied with the mixture amplitude spectrogram, and the product is
        compared to the target amplitude spectrogram with a scaled MSE.

        Args:
            tf_mix: Mixture tensor; indexed as ``[:, :384]``, so at least
                rank 2.
            tf_target: Target tensor, transformed with the main STFT module.
            tf_lr: Learning rate forwarded to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, est_spec)``.
        """

        def build_stft(params):
            """Create an ``STFT_Module`` from one parameter mapping."""
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        main_stft = build_stft(self.stft_params)
        mr2_stft = build_stft(self.mr2_stft_params)

        # Main-resolution mixture: STFT -> amplitude -> log-amplitude.
        mix_spec = main_stft.STFT(tf_mix)
        mix_amp = main_stft.to_amp_spec(mix_spec, normalize=False)
        # tf.math.log is the supported name; the bare tf.log alias is
        # deprecated and no longer exists in TF 2.x.
        mix_log = tf.math.log(mix_amp + self.epsilon)
        # Trailing axis added; the original author labels the layout
        # (Batch, Time, Freq, Channel).
        mix_log = tf.expand_dims(mix_log, -1)
        mix_amp = tf.expand_dims(mix_amp, -1)
        # presumably trims the frequency axis to 512 bins (see the
        # 512 -> 513 padding below) -- confirm against STFT_Module.to_F_512
        mix_log_512 = main_stft.to_F_512(mix_log)

        # mr2 mixture. Original author's note: zero pad to fit stft time
        # length 128. The pad block is built from a slice of the input, so
        # no full-length zeros tensor is allocated.
        edge_pad = tf.zeros_like(tf_mix[:, :384])
        padded_mix = tf.concat([edge_pad, tf_mix, edge_pad], axis=1)
        mr2_spec = mr2_stft.STFT(padded_mix)
        # NOTE(review): amplitude conversion goes through the main STFT
        # module rather than the mr2 one -- confirm to_amp_spec does not
        # depend on the STFT parameters.
        mr2_amp = main_stft.to_amp_spec(mr2_spec, normalize=False)
        mr2_log = tf.math.log(mr2_amp + self.epsilon)
        mr2_log = tf.expand_dims(mr2_log, -1)
        mr2_log = mr2_log[:, :, :1024, :]

        # Target amplitude spectrogram with the same trailing axis.
        target_spec = main_stft.STFT(tf_target)
        target_amp = main_stft.to_amp_spec(target_spec, normalize=False)
        target_amp = tf.expand_dims(target_amp, -1)

        network = mini_UNet_ver4(
            input_shape=(mix_log_512.shape[1:]),
            mr2_input_shape=(mr2_log.shape[1:]))

        # Only the first of the six network outputs is used here.
        est_masks, _, _, _, _, _ = network(mix_log_512, mr2_log)

        # Append one all-zero bin on axis 2 (F: 512 -> 513). Slicing before
        # zeros_like avoids materialising a full-size zeros tensor.
        zero_bin = tf.zeros_like(mix_log[:, :, :1, :])
        est_masks = tf.concat([est_masks, zero_bin], 2)

        est_spec = tf.math.multiply(est_masks, mix_amp)
        loss = 10 * Loss.mean_square_error(est_spec, target_amp)
        train_step = Trainer.Adam(loss, tf_lr)

        return (train_step, loss, target_amp, mix_log, mix_spec, est_masks,
                est_spec)
Ejemplo n.º 3
0
    def __model(self, tf_mix, tf_target, tf_lr):
        """Assemble the training graph for the ``Conv_FFN`` mask estimator.

        The mixture's log-amplitude spectrogram is fed to ``Conv_FFN``; the
        resulting mask gets one zero bin appended on axis 2, is multiplied
        with the mixture amplitude spectrogram, and the product is compared
        to the target amplitude spectrogram with a scaled MSE.

        Args:
            tf_mix: Mixture tensor passed to the STFT module.
            tf_target: Target tensor passed to the same STFT module.
            tf_lr: Learning rate forwarded to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, est_spec)``.
        """
        stft = STFT_Module(
            frame_length=self.stft_params["frame_length"],
            frame_step=self.stft_params["frame_step"],
            fft_length=self.stft_params["fft_length"],
            epsilon=self.epsilon,
            pad_end=self.stft_params["pad_end"])

        # Mixture: STFT -> amplitude -> log-amplitude.
        mix_spec = stft.STFT(tf_mix)
        mix_amp = stft.to_amp_spec(mix_spec, normalize=False)
        # tf.math.log is the supported name; the bare tf.log alias is
        # deprecated and no longer exists in TF 2.x.
        mix_log = tf.math.log(mix_amp + self.epsilon)
        # Trailing axis added; the original author labels the layout
        # (Batch, Time, Freq, Channel).
        mix_log = tf.expand_dims(mix_log, -1)
        mix_amp = tf.expand_dims(mix_amp, -1)
        # presumably trims the frequency axis to 512 bins (see the
        # 512 -> 513 padding below) -- confirm against STFT_Module.to_F_512
        mix_log_512 = stft.to_F_512(mix_log)

        # Target amplitude spectrogram with the same trailing axis.
        target_spec = stft.STFT(tf_target)
        target_amp = stft.to_amp_spec(target_spec, normalize=False)
        target_amp = tf.expand_dims(target_amp, -1)

        network = Conv_FFN(
            input_shape=(mix_log_512.shape[1:]),
            out_dim=512,
            h_dim=512,
        )
        est_masks = network(mix_log_512)

        # Append one all-zero bin on axis 2 (F: 512 -> 513). Slicing before
        # zeros_like avoids materialising a full-size zeros tensor.
        zero_bin = tf.zeros_like(mix_log[:, :, :1, :])
        est_masks = tf.concat([est_masks, zero_bin], 2)

        est_spec = tf.math.multiply(est_masks, mix_amp)
        loss = 10 * Loss.mean_square_error(est_spec, target_amp)
        train_step = Trainer.Adam(loss, tf_lr)

        return (train_step, loss, target_amp, mix_log, mix_spec, est_masks,
                est_spec)
Ejemplo n.º 4
0
    def __model(self, tf_mix, tf_target, tf_lr):
        """Assemble the training graph for the ``FFN_ver2`` mask estimator.

        The mixture is analysed with three STFT configurations
        (``self.stft_params``, ``self.mr1_stft_params`` and
        ``self.mr2_stft_params``). The three log-amplitude spectrograms are
        fed to ``FFN_ver2``; the resulting mask gets one zero bin appended on
        axis 2, is multiplied with the mixture amplitude spectrogram, and the
        product is compared to the target amplitude spectrogram with a scaled
        MSE. Unlike the convolutional variants, no trailing channel axis is
        added to any spectrogram here.

        Args:
            tf_mix: Mixture tensor; indexed as ``[:, :384]``, so at least
                rank 2.
            tf_target: Target tensor, transformed with the main STFT module.
            tf_lr: Learning rate forwarded to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, est_spec)``.
        """
        import logging

        logger = logging.getLogger(__name__)

        def build_stft(params):
            """Create an ``STFT_Module`` from one parameter mapping."""
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        main_stft = build_stft(self.stft_params)
        mr1_stft = build_stft(self.mr1_stft_params)
        mr2_stft = build_stft(self.mr2_stft_params)

        # Main-resolution mixture: STFT -> amplitude -> log-amplitude.
        mix_spec = main_stft.STFT(tf_mix)
        mix_amp = main_stft.to_amp_spec(mix_spec, normalize=False)
        # tf.math.log is the supported name; the bare tf.log alias is
        # deprecated and no longer exists in TF 2.x.
        mix_log = tf.math.log(mix_amp + self.epsilon)
        # presumably trims the frequency axis to 512 bins (see the
        # 512 -> 513 padding below) -- confirm against STFT_Module.to_F_512
        mix_log_512 = main_stft.to_F_512(mix_log)

        # mr1 mixture: keep indices 1..512 of axis 1, then the first 256
        # entries of axis 2.
        # NOTE(review): the mr1/mr2 amplitude conversions go through the
        # main STFT module -- confirm to_amp_spec does not depend on the
        # STFT parameters.
        mr1_spec = mr1_stft.STFT(tf_mix)[:, 1:513, :]
        mr1_amp = main_stft.to_amp_spec(mr1_spec, normalize=False)
        mr1_log = tf.math.log(mr1_amp + self.epsilon)
        mr1_log_256 = mr1_log[:, :, :256]

        # mr2 mixture. Original author's note: zero pad to fit stft time
        # length 128. The pad block is built from a slice of the input, so
        # no full-length zeros tensor is allocated.
        edge_pad = tf.zeros_like(tf_mix[:, :384])
        padded_mix = tf.concat([edge_pad, tf_mix, edge_pad], axis=1)
        mr2_spec = mr2_stft.STFT(padded_mix)
        mr2_amp = main_stft.to_amp_spec(mr2_spec, normalize=False)
        mr2_log = tf.math.log(mr2_amp + self.epsilon)
        mr2_log = mr2_log[:, :, :1024]

        # Target amplitude spectrogram.
        target_spec = main_stft.STFT(tf_target)
        target_amp = main_stft.to_amp_spec(target_spec, normalize=False)

        network = FFN_ver2(
            out_dim=512,
            h_dim=512,
        )
        est_masks = network(mix_log_512, mr1_log_256, mr2_log)

        # Append one all-zero bin on axis 2 (F: 512 -> 513). Slicing before
        # zeros_like avoids materialising a full-size zeros tensor.
        zero_bin = tf.zeros_like(mix_log[:, :, :1])
        est_masks = tf.concat([est_masks, zero_bin], 2)

        # Shape diagnostics go to the logger instead of stdout so that
        # building the graph stays silent unless debugging is enabled.
        logger.debug("est_mask %s", est_masks.shape)
        logger.debug("amp_spec_mix %s", mix_amp.shape)

        est_spec = tf.math.multiply(est_masks, mix_amp)
        loss = 10 * Loss.mean_square_error(est_spec, target_amp)
        train_step = Trainer.Adam(loss, tf_lr)

        return (train_step, loss, target_amp, mix_log, mix_spec, est_masks,
                est_spec)
    def __model(self, tf_mix, tf_target, tf_lr):
        """Assemble the training graph for a ``UNet`` fed with stacked resolutions.

        The mixture is analysed with three STFT configurations
        (``self.stft_params``, ``self.mr1_stft_params`` and
        ``self.mr2_stft_params``). The three log-amplitude spectrograms are
        concatenated along axis 3, cut to the first 1024 entries of axis 2
        and fed to ``UNet``. The resulting mask gets one zero bin prepended
        on axis 2, is multiplied with the mixture amplitude spectrogram, and
        the product is compared to the target amplitude spectrogram with a
        scaled MSE.

        Args:
            tf_mix: Mixture tensor passed to all three STFT modules.
            tf_target: Target tensor, transformed with the main STFT module.
            tf_lr: Learning rate forwarded to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, est_spec)``.
        """
        import logging

        logger = logging.getLogger(__name__)

        def build_stft(params):
            """Create an ``STFT_Module`` from one parameter mapping."""
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        main_stft = build_stft(self.stft_params)
        mr1_stft = build_stft(self.mr1_stft_params)
        mr2_stft = build_stft(self.mr2_stft_params)

        # Main-resolution mixture: STFT -> amplitude -> log-amplitude.
        mix_spec = main_stft.STFT(tf_mix)
        mix_amp = main_stft.to_amp_spec(mix_spec, normalize=False)
        # tf.math.log is the supported name; the bare tf.log alias is
        # deprecated and no longer exists in TF 2.x.
        mix_log = tf.math.log(mix_amp + self.epsilon)
        # Trailing axis added; the original author labels the layout
        # (Batch, Time, Freq, Channel).
        mix_log = tf.expand_dims(mix_log, -1)
        mix_amp = tf.expand_dims(mix_amp, -1)

        # mr1 / mr2 mixtures, same pipeline with their own STFT modules.
        # NOTE(review): the amplitude conversions go through the main STFT
        # module -- confirm to_amp_spec does not depend on the STFT
        # parameters.
        mr1_spec = mr1_stft.STFT(tf_mix)
        mr1_amp = main_stft.to_amp_spec(mr1_spec, normalize=False)
        mr1_log = tf.math.log(mr1_amp + self.epsilon)
        mr1_log = tf.expand_dims(mr1_log, -1)

        mr2_spec = mr2_stft.STFT(tf_mix)
        mr2_amp = main_stft.to_amp_spec(mr2_spec, normalize=False)
        mr2_log = tf.math.log(mr2_amp + self.epsilon)
        mr2_log = tf.expand_dims(mr2_log, -1)

        # Target amplitude spectrogram with the same trailing axis.
        target_spec = main_stft.STFT(tf_target)
        target_amp = main_stft.to_amp_spec(target_spec, normalize=False)
        target_amp = tf.expand_dims(target_amp, -1)

        # Stack the three resolutions on the last axis; concat requires the
        # remaining axes of the three spectrograms to agree.
        stacked = tf.concat([mix_log, mr1_log, mr2_log], 3)
        # Shape diagnostics go to the logger instead of stdout so that
        # building the graph stays silent unless debugging is enabled.
        logger.debug("input spec shape before crop: %s", stacked.shape)
        stacked = stacked[:, :, :1024, :]
        logger.debug("input spec shape after crop: %s", stacked.shape)

        network = UNet(input_shape=stacked.shape[1:])
        est_masks = network(stacked)

        # Prepend one all-zero bin on axis 2. Slicing before zeros_like
        # avoids materialising a full-size zeros tensor.
        # NOTE(review): the input crop above drops trailing entries of
        # axis 2 while the zero bin is added at the front -- confirm this
        # one-bin offset between input and mask is intended.
        zero_bin = tf.zeros_like(mix_log[:, :, :1, :])
        est_masks = tf.concat([zero_bin, est_masks], 2)

        est_spec = tf.math.multiply(est_masks, mix_amp)
        loss = 10 * Loss.mean_square_error(est_spec, target_amp)
        train_step = Trainer.Adam(loss, tf_lr)

        return (train_step, loss, target_amp, mix_log, mix_spec, est_masks,
                est_spec)