Beispiel #1
0
 def __model(self, tf_mix):
         """Build the inference graph for the two-resolution mini U-Net (ver2).

         The mixture is analysed with two STFTs: the main one configured by
         ``self.stft_params`` and a second one configured by
         ``self.mr_stft_params``. Log-amplitude features from both are fed to
         ``mini_UNet_ver2``. The first network output is used as a mask on the
         main-resolution amplitude spectrogram. The mixture phase is then
         re-attached and the result is inverted with ISTFT.

         Args:
             tf_mix: mixture waveform tensor passed to ``STFT_Module.STFT``.

         Returns:
             The estimated source waveform produced by ``STFT_Module.ISTFT``.
         """
         stft_module = STFT_Module(
                 frame_length=self.stft_params["frame_length"],
                 frame_step=self.stft_params["frame_step"],
                 fft_length=self.stft_params["fft_length"],
                 epsilon=self.epsilon,
                 pad_end=self.stft_params["pad_end"]
         )

         # The second-resolution module takes its padding flag from its own
         # parameter set. It falls back to the main setting only when
         # mr_stft_params does not define "pad_end" (the previous behaviour).
         mr_stft_module = STFT_Module(
                 frame_length=self.mr_stft_params["frame_length"],
                 frame_step=self.mr_stft_params["frame_step"],
                 fft_length=self.mr_stft_params["fft_length"],
                 epsilon=self.epsilon,
                 pad_end=self.mr_stft_params.get(
                         "pad_end", self.stft_params["pad_end"])
         )

         # mix data transform (main resolution)
         tf_spec_mix = stft_module.STFT(tf_mix)
         tf_phase_mix = tf.sign(tf_spec_mix)
         tf_phase_mix = self.expand_channel(tf_phase_mix)
         tf_amp_spec_mix = stft_module.to_amp_spec(tf_spec_mix, normalize=False)
         tf_mag_spec_mix = tf.log(tf_amp_spec_mix + self.epsilon)
         tf_mag_spec_mix = tf.expand_dims(tf_mag_spec_mix, -1)  # (Batch, Time, Freq, Channel)
         tf_amp_spec_mix = tf.expand_dims(tf_amp_spec_mix, -1)
         tf_f_512_mag_spec_mix = stft_module.to_F_512(tf_mag_spec_mix)

         # mix data transform (second resolution): keep indices 1..512 on
         # axis 1 and the first 256 bins on axis 2
         tf_mr_spec_mix = mr_stft_module.STFT(tf_mix)
         tf_mr_spec_mix = tf_mr_spec_mix[:, 1:513, :]
         tf_mr_amp_spec_mix = stft_module.to_amp_spec(tf_mr_spec_mix, normalize=False)
         tf_mr_mag_spec_mix = tf.log(tf_mr_amp_spec_mix + self.epsilon)
         tf_mr_mag_spec_mix = tf.expand_dims(tf_mr_mag_spec_mix, -1)
         tf_mr_f_256_mag_spec_mix = tf_mr_mag_spec_mix[:, :, :256]

         mini_u_net_ver2 = mini_UNet_ver2(
                 input_shape=tf_f_512_mag_spec_mix.shape[1:],
                 mr1_input_shape=tf_mr_f_256_mag_spec_mix.shape[1:]
         )
         # Only the first of the six network outputs (the mask) is used.
         tf_est_masks, _, _, _, _, _ = mini_u_net_ver2(
                 tf_f_512_mag_spec_mix, tf_mr_f_256_mag_spec_mix)

         # F: 512 -> 513. Append one all-zero bin at the end of the frequency
         # axis so the mask can be multiplied with the amplitude spectrogram.
         zero_pad = tf.zeros_like(tf_mag_spec_mix)
         zero_pad = tf.expand_dims(zero_pad[:, :, 1, :], -1)
         tf_est_masks = tf.concat([tf_est_masks, zero_pad], 2)
         tf_est_spec = tf.math.multiply(tf_est_masks, tf_amp_spec_mix)
         # Re-attach the mixture phase, then drop the channel axis for ISTFT.
         tf_est_source_spec = tf.math.multiply(tf.complex(tf_est_spec, 0.), tf_phase_mix)
         tf_est_source_spec = tf.squeeze(tf_est_source_spec, axis=-1)
         est_source = stft_module.ISTFT(tf_est_source_spec)
         return est_source
        def __model(self, tf_mix):
                """Build the inference graph for the single-resolution U-Net.

                The log-amplitude of the mixture STFT, truncated to its first
                1024 entries on the frequency axis, is fed to ``UNet``. A zero
                column is prepended to the predicted mask along the frequency
                axis. The mask is then applied to the mixture amplitude, the
                mixture phase is re-attached, and the result is inverted with
                ISTFT.

                Args:
                    tf_mix: mixture waveform tensor passed to ``STFT_Module.STFT``.

                Returns:
                    The estimated source waveform produced by ``STFT_Module.ISTFT``.
                """
                stft_op = STFT_Module(
                        frame_length=self.stft_params["frame_length"],
                        frame_step=self.stft_params["frame_step"],
                        fft_length=self.stft_params["fft_length"],
                        epsilon=self.epsilon,
                        pad_end=self.stft_params["pad_end"],
                )

                spec = stft_op.STFT(tf_mix)
                phase = self.expand_channel(tf.sign(spec))

                amp = stft_op.to_amp_spec(spec, normalize=False)
                log_amp = tf.expand_dims(tf.log(amp + self.epsilon), -1)  # (Batch, Time, Freq, Channel)
                amp = tf.expand_dims(amp, -1)
                net_in = log_amp[:, :, :1024, :]

                unet = UNet(input_shape=net_in.shape[1:])
                masks = unet(net_in)

                # NOTE(review): the network sees frequency entries 0..1023, but
                # the zero column is placed *before* its output. Assuming output
                # bin i corresponds to input bin i, mask i lands on bin i + 1.
                # Confirm this one-bin offset is intentional.
                pad_col = tf.expand_dims(tf.zeros_like(net_in)[:, :, 1, :], -1)
                masks = tf.concat([pad_col, masks], 2)

                est_amp = tf.math.multiply(masks, amp)
                est_spec = tf.math.multiply(tf.complex(est_amp, 0.), phase)
                # Drop the trailing channel axis before inverting.
                return stft_op.ISTFT(tf.squeeze(est_spec, axis=-1))
    def __model(self, tf_mix):
        """Build the inference graph for the convolutional FFN (Conv_FFN).

        The mixture STFT is cut in time with ``to_T_256``. Its log-amplitude
        is reduced to 512 frequency bins with ``to_F_512`` and passed to
        ``Conv_FFN``. The predicted mask is padded by one zero bin on the
        frequency axis and applied to the mixture amplitude. It is then
        combined with the mixture phase and inverted with ISTFT.

        Args:
            tf_mix: mixture waveform tensor passed to ``STFT_Module.STFT``.

        Returns:
            The estimated source waveform produced by ``STFT_Module.ISTFT``.
        """
        stft_module = STFT_Module(
            frame_length=self.stft_params["frame_length"],
            frame_step=self.stft_params["frame_step"],
            fft_length=self.stft_params["fft_length"],
            pad_end=self.stft_params["pad_end"],
            epsilon=self.epsilon)

        # mix data transform
        tf_spec_mix = stft_module.STFT(tf_mix)
        # cut time dimension to 256 for u-net architecture
        tf_spec_mix = stft_module.to_T_256(tf_spec_mix)
        tf_phase_mix = tf.sign(tf_spec_mix)
        tf_phase_mix = self.expand_channel(tf_phase_mix)
        tf_amp_spec_mix = stft_module.to_amp_spec(tf_spec_mix, normalize=False)
        tf_mag_spec_mix = tf.log(tf_amp_spec_mix + self.epsilon)
        # (Batch, Time, Freq, Channel)
        tf_mag_spec_mix = tf.expand_dims(tf_mag_spec_mix, -1)
        tf_amp_spec_mix = tf.expand_dims(tf_amp_spec_mix, -1)
        tf_f_512_mag_spec_mix = stft_module.to_F_512(tf_mag_spec_mix)

        conv_ffn = Conv_FFN(
            input_shape=tf_f_512_mag_spec_mix.shape[1:],
            out_dim=512,
            h_dim=512,
        )
        tf_est_masks = conv_ffn(tf_f_512_mag_spec_mix)

        # F: 512 -> 513: append one all-zero bin at the end of the frequency axis.
        zero_pad = tf.zeros_like(tf_mag_spec_mix)
        zero_pad = tf.expand_dims(zero_pad[:, :, 1, :], -1)
        tf_est_masks = tf.concat([tf_est_masks, zero_pad], 2)
        tf_est_spec = tf.math.multiply(tf_est_masks, tf_amp_spec_mix)
        # Re-attach the mixture phase, then drop the channel axis for ISTFT.
        tf_est_source_spec = tf.math.multiply(tf.complex(tf_est_spec, 0.),
                                              tf_phase_mix)
        tf_est_source_spec = tf.squeeze(tf_est_source_spec, axis=-1)
        est_source = stft_module.ISTFT(tf_est_source_spec)
        return est_source
Beispiel #4
0
        def __model(self, tf_mix, tf_target, tf_lr):
                """Build the training graph for the U-Net mask estimator (IAF target).

                The U-Net predicts a mask from the 512-bin log-amplitude of the
                mixture, and the mask is padded by one zero bin on the frequency
                axis. It is regressed onto the oracle mask that ``Masks.iaf``
                computes from the mixture and target amplitudes. The loss is the
                mean squared error scaled by 10. ``Trainer.Adam`` builds the
                update step with learning rate ``tf_lr``.

                Args:
                    tf_mix: mixture waveform tensor.
                    tf_target: target-source waveform tensor.
                    tf_lr: learning rate passed to ``Trainer.Adam``.

                Returns:
                    Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
                    mix_spec, est_masks, oracle_masks)``.
                """
                stft_op = STFT_Module(
                        frame_length=self.stft_params["frame_length"],
                        frame_step=self.stft_params["frame_step"],
                        fft_length=self.stft_params["fft_length"],
                        epsilon=self.epsilon,
                        pad_end=self.stft_params["pad_end"],
                )

                # Mixture: complex spectrogram, amplitude, log-amplitude.
                mix_spec = stft_op.STFT(tf_mix)
                mix_amp = stft_op.to_amp_spec(mix_spec, normalize=False)
                mix_log_amp = tf.expand_dims(tf.log(mix_amp + self.epsilon), -1)  # (Batch, Time, Freq, Channel)
                mix_amp = tf.expand_dims(mix_amp, -1)
                net_in = stft_op.to_F_512(mix_log_amp)

                # Target amplitude, used only to build the oracle mask.
                target_amp = tf.expand_dims(
                        stft_op.to_amp_spec(stft_op.STFT(tf_target), normalize=False), -1)

                unet = UNet(input_shape=net_in.shape[1:])
                est_masks = unet(net_in)

                # F: 512 -> 513: append one all-zero bin at the end of the frequency axis.
                pad_col = tf.expand_dims(tf.zeros_like(mix_log_amp)[:, :, 1, :], -1)
                est_masks = tf.concat([est_masks, pad_col], 2)

                oracle_masks = Masks.iaf(mix_amp, target_amp, self.epsilon)
                loss = 10 * Loss.mean_square_error(est_masks, oracle_masks)
                train_step = Trainer.Adam(loss, tf_lr)

                return (train_step, loss, target_amp, mix_log_amp,
                        mix_spec, est_masks, oracle_masks)
Beispiel #5
0
    def __model(self, tf_mix):
        """Build the inference graph for the three-resolution FFN (FFN_ver2).

        Log-amplitude features come from three STFTs of the mixture,
        configured by ``self.stft_params``, ``self.mr1_stft_params`` and
        ``self.mr2_stft_params``, and are passed to ``FFN_ver2``. The
        predicted mask is padded by one zero bin on the frequency axis and
        applied to the mixture amplitude. It is then combined with the
        mixture phase and inverted with ISTFT.

        No trailing channel axis is added to any spectrogram in this graph.

        Args:
            tf_mix: mixture waveform tensor.

        Returns:
            The estimated source waveform produced by ``STFT_Module.ISTFT``.
        """
        def make_stft(params):
            # One STFT_Module per parameter dict; epsilon is shared.
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        stft_module = make_stft(self.stft_params)
        mr1_stft_module = make_stft(self.mr1_stft_params)
        mr2_stft_module = make_stft(self.mr2_stft_params)

        # mix data transform (main resolution)
        tf_spec_mix = stft_module.STFT(tf_mix)
        tf_phase_mix = tf.sign(tf_spec_mix)
        tf_amp_spec_mix = stft_module.to_amp_spec(tf_spec_mix, normalize=False)
        tf_mag_spec_mix = tf.log(tf_amp_spec_mix + self.epsilon)
        tf_f_512_mag_spec_mix = stft_module.to_F_512(tf_mag_spec_mix)

        # mr1 mix data transform: indices 1..512 on axis 1, first 256 bins on axis 2
        tf_mr1_spec_mix = mr1_stft_module.STFT(tf_mix)
        tf_mr1_spec_mix = tf_mr1_spec_mix[:, 1:513, :]
        tf_mr1_amp_spec_mix = stft_module.to_amp_spec(tf_mr1_spec_mix,
                                                      normalize=False)
        tf_mr1_mag_spec_mix = tf.log(tf_mr1_amp_spec_mix + self.epsilon)
        tf_mr1_f_256_mag_spec_mix = tf_mr1_mag_spec_mix[:, :, :256]

        # mr2 mix data transform
        # zero pad to fit stft time length 128
        mr2_zero_pad = tf.zeros_like(tf_mix)
        tf_mr2_mix = tf.concat(
            [mr2_zero_pad[:, :384], tf_mix, mr2_zero_pad[:, :384]], axis=1)
        tf_mr2_spec_mix = mr2_stft_module.STFT(tf_mr2_mix)
        tf_mr2_amp_spec_mix = stft_module.to_amp_spec(tf_mr2_spec_mix,
                                                      normalize=False)
        tf_mr2_mag_spec_mix = tf.log(tf_mr2_amp_spec_mix + self.epsilon)
        tf_mr2_mag_spec_mix = tf_mr2_mag_spec_mix[:, :, :1024]

        ffn_ver2 = FFN_ver2(
            out_dim=512,
            h_dim=512,
        )
        tf_est_masks = ffn_ver2(tf_f_512_mag_spec_mix,
                                tf_mr1_f_256_mag_spec_mix, tf_mr2_mag_spec_mix)

        # F: 512 -> 513: append one all-zero bin at the end of the frequency axis.
        zero_pad = tf.zeros_like(tf_mag_spec_mix)
        zero_pad = tf.expand_dims(zero_pad[:, :, 1], -1)
        tf_est_masks = tf.concat([tf_est_masks, zero_pad], 2)
        tf_est_spec = tf.math.multiply(tf_est_masks, tf_amp_spec_mix)
        tf_est_source_spec = tf.math.multiply(tf.complex(tf_est_spec, 0.),
                                              tf_phase_mix)
        # No channel axis was added above (neither expand_channel nor
        # expand_dims is applied), so the spectrogram already has the same
        # rank as the STFT output. The former tf.squeeze(axis=-1) targeted the
        # frequency axis, which is not of size 1, and tf.squeeze rejects that.
        est_source = stft_module.ISTFT(tf_est_source_spec)
        return est_source
    def __model(self, tf_mix, tf_target, tf_lr):
        """Build the training graph for the three-resolution mini U-Net (ver3).

        Log-amplitude features come from three STFTs of the mixture,
        configured by ``self.stft_params``, ``self.mr1_stft_params`` and
        ``self.mr2_stft_params``, and are fed to ``mini_UNet_ver3``. Its first
        output is padded by one zero bin on the frequency axis and used as a
        mask on the mixture amplitude. The masked amplitude is regressed onto
        the target amplitude: the loss is the mean squared error scaled by 10,
        and ``Trainer.Adam`` builds the update step.

        Args:
            tf_mix: mixture waveform tensor.
            tf_target: target-source waveform tensor.
            tf_lr: learning rate passed to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, est_amp_spec)``.
        """
        def make_stft(params):
            # One STFT_Module per parameter dict; epsilon is shared.
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        def log_amp(spec):
            # Log-amplitude (via the main module) with a trailing channel axis.
            amp = main_stft.to_amp_spec(spec, normalize=False)
            return tf.expand_dims(tf.log(amp + self.epsilon), -1)

        main_stft = make_stft(self.stft_params)
        mr1_stft = make_stft(self.mr1_stft_params)
        mr2_stft = make_stft(self.mr2_stft_params)

        # Main resolution: complex spectrogram, amplitude and log-amplitude.
        mix_spec = main_stft.STFT(tf_mix)
        mix_amp = main_stft.to_amp_spec(mix_spec, normalize=False)
        mix_log_amp = tf.expand_dims(tf.log(mix_amp + self.epsilon), -1)  # (Batch, Time, Freq, Channel)
        mix_amp = tf.expand_dims(mix_amp, -1)
        main_in = main_stft.to_F_512(mix_log_amp)

        # Resolution 1: indices 1..512 on axis 1, first 256 bins on axis 2.
        mr1_in = log_amp(mr1_stft.STFT(tf_mix)[:, 1:513, :])[:, :, :256, :]

        # Resolution 2: zero pad 384 entries on both sides of axis 1 (to fit
        # stft time length 128), then keep the first 1024 bins.
        silence = tf.zeros_like(tf_mix)
        padded_mix = tf.concat(
            [silence[:, :384], tf_mix, silence[:, :384]], axis=1)
        mr2_in = log_amp(mr2_stft.STFT(padded_mix))[:, :, :1024, :]

        # Target amplitude with a channel axis.
        target_amp = tf.expand_dims(
            main_stft.to_amp_spec(main_stft.STFT(tf_target), normalize=False),
            -1)

        net = mini_UNet_ver3(
            input_shape=main_in.shape[1:],
            mr1_input_shape=mr1_in.shape[1:],
            mr2_input_shape=mr2_in.shape[1:])
        # Only the first of the six network outputs (the mask) is used.
        est_masks, _, _, _, _, _ = net(main_in, mr1_in, mr2_in)

        # F: 512 -> 513: append one all-zero bin at the end of the frequency axis.
        pad_col = tf.expand_dims(tf.zeros_like(mix_log_amp)[:, :, 1, :], -1)
        est_masks = tf.concat([est_masks, pad_col], 2)
        est_amp = tf.math.multiply(est_masks, mix_amp)

        loss = 10 * Loss.mean_square_error(est_amp, target_amp)
        train_step = Trainer.Adam(loss, tf_lr)

        return (train_step, loss, target_amp, mix_log_amp, mix_spec,
                est_masks, est_amp)
Beispiel #7
0
    def __model(self, tf_mix, tf_target, tf_lr):
        """Build the training graph for the three-resolution FFN (FFN_ver2).

        Log-amplitude features come from three STFTs of the mixture,
        configured by ``self.stft_params``, ``self.mr1_stft_params`` and
        ``self.mr2_stft_params``, and are passed to ``FFN_ver2``. The
        predicted mask is padded by one zero bin on the frequency axis and
        applied to the mixture amplitude. The loss is 10 x the mean squared
        error between that estimate and the target amplitude, and
        ``Trainer.Adam`` builds the update step.

        No trailing channel axis is added to any spectrogram in this graph.

        Args:
            tf_mix: mixture waveform tensor.
            tf_target: target-source waveform tensor.
            tf_lr: learning rate passed to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, est_amp_spec)``.
        """
        def make_stft(params):
            # One STFT_Module per parameter dict; epsilon is shared.
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        stft_module = make_stft(self.stft_params)
        mr1_stft_module = make_stft(self.mr1_stft_params)
        mr2_stft_module = make_stft(self.mr2_stft_params)

        # mix data transform
        tf_spec_mix = stft_module.STFT(tf_mix)
        tf_amp_spec_mix = stft_module.to_amp_spec(tf_spec_mix, normalize=False)
        tf_mag_spec_mix = tf.log(tf_amp_spec_mix + self.epsilon)
        tf_f_512_mag_spec_mix = stft_module.to_F_512(tf_mag_spec_mix)

        # mr1 mix data transform: indices 1..512 on axis 1, first 256 bins on axis 2
        tf_mr1_spec_mix = mr1_stft_module.STFT(tf_mix)
        tf_mr1_spec_mix = tf_mr1_spec_mix[:, 1:513, :]
        tf_mr1_amp_spec_mix = stft_module.to_amp_spec(tf_mr1_spec_mix,
                                                      normalize=False)
        tf_mr1_mag_spec_mix = tf.log(tf_mr1_amp_spec_mix + self.epsilon)
        tf_mr1_f_256_mag_spec_mix = tf_mr1_mag_spec_mix[:, :, :256]

        # mr2 mix data transform
        # zero pad to fit stft time length 128
        mr2_zero_pad = tf.zeros_like(tf_mix)
        tf_mr2_mix = tf.concat(
            [mr2_zero_pad[:, :384], tf_mix, mr2_zero_pad[:, :384]], axis=1)
        tf_mr2_spec_mix = mr2_stft_module.STFT(tf_mr2_mix)
        tf_mr2_amp_spec_mix = stft_module.to_amp_spec(tf_mr2_spec_mix,
                                                      normalize=False)
        tf_mr2_mag_spec_mix = tf.log(tf_mr2_amp_spec_mix + self.epsilon)
        tf_mr2_mag_spec_mix = tf_mr2_mag_spec_mix[:, :, :1024]

        # target data transform
        tf_spec_target = stft_module.STFT(tf_target)
        tf_amp_spec_target = stft_module.to_amp_spec(tf_spec_target,
                                                     normalize=False)

        ffn_ver2 = FFN_ver2(
            out_dim=512,
            h_dim=512,
        )
        tf_est_masks = ffn_ver2(tf_f_512_mag_spec_mix,
                                tf_mr1_f_256_mag_spec_mix, tf_mr2_mag_spec_mix)

        # F: 512 -> 513: append one all-zero bin at the end of the frequency axis.
        zero_pad = tf.zeros_like(tf_mag_spec_mix)
        zero_pad = tf.expand_dims(zero_pad[:, :, 1], -1)
        tf_est_masks = tf.concat([tf_est_masks, zero_pad], 2)
        tf_est_spec = tf.math.multiply(tf_est_masks, tf_amp_spec_mix)

        tf_loss = 10 * Loss.mean_square_error(tf_est_spec, tf_amp_spec_target)
        tf_train_step = Trainer.Adam(tf_loss, tf_lr)

        return (tf_train_step, tf_loss, tf_amp_spec_target, tf_mag_spec_mix,
                tf_spec_mix, tf_est_masks, tf_est_spec)
    def __model(self, tf_mix, tf_target, tf_lr):
        """Build the training graph for the multi-resolution input U-Net.

        Log-amplitude spectrograms of the mixture from three STFTs, configured
        by ``self.stft_params``, ``self.mr1_stft_params`` and
        ``self.mr2_stft_params``, are stacked along axis 3. They are
        truncated to the first 1024 entries on axis 2 and fed to ``UNet``. A
        zero column is prepended to the predicted mask on axis 2, and the
        mask is applied to the main mixture amplitude. The loss is 10 x the
        mean squared error against the target amplitude, and
        ``Trainer.Adam`` builds the update step.

        Args:
            tf_mix: mixture waveform tensor.
            tf_target: target-source waveform tensor.
            tf_lr: learning rate passed to ``Trainer.Adam``.

        Returns:
            Tuple ``(train_step, loss, target_amp_spec, mix_log_amp_spec,
            mix_spec, est_masks, est_amp_spec)``.
        """
        def make_stft(params):
            # One STFT_Module per parameter dict; epsilon is shared.
            return STFT_Module(
                frame_length=params["frame_length"],
                frame_step=params["frame_step"],
                fft_length=params["fft_length"],
                epsilon=self.epsilon,
                pad_end=params["pad_end"])

        stft_module = make_stft(self.stft_params)
        mr1_stft_module = make_stft(self.mr1_stft_params)
        mr2_stft_module = make_stft(self.mr2_stft_params)

        # mix data transform
        tf_spec_mix = stft_module.STFT(tf_mix)
        tf_amp_spec_mix = stft_module.to_amp_spec(tf_spec_mix, normalize=False)
        tf_mag_spec_mix = tf.log(tf_amp_spec_mix + self.epsilon)
        # (Batch, Time, Freq, Channel)
        tf_mag_spec_mix = tf.expand_dims(tf_mag_spec_mix, -1)
        tf_amp_spec_mix = tf.expand_dims(tf_amp_spec_mix, -1)

        # mr1 mix data transform
        tf_mr1_spec_mix = mr1_stft_module.STFT(tf_mix)
        tf_mr1_amp_spec_mix = stft_module.to_amp_spec(tf_mr1_spec_mix,
                                                      normalize=False)
        tf_mr1_mag_spec_mix = tf.log(tf_mr1_amp_spec_mix + self.epsilon)
        tf_mr1_mag_spec_mix = tf.expand_dims(tf_mr1_mag_spec_mix, -1)

        # mr2 mix data transform
        tf_mr2_spec_mix = mr2_stft_module.STFT(tf_mix)
        tf_mr2_amp_spec_mix = stft_module.to_amp_spec(tf_mr2_spec_mix,
                                                      normalize=False)
        tf_mr2_mag_spec_mix = tf.log(tf_mr2_amp_spec_mix + self.epsilon)
        tf_mr2_mag_spec_mix = tf.expand_dims(tf_mr2_mag_spec_mix, -1)

        # target data transform
        tf_spec_target = stft_module.STFT(tf_target)
        tf_amp_spec_target = stft_module.to_amp_spec(tf_spec_target,
                                                     normalize=False)
        tf_amp_spec_target = tf.expand_dims(tf_amp_spec_target, -1)

        # NOTE(review): concatenating on axis 3 assumes the three STFT
        # configurations yield spectrograms with identical axis-1/axis-2
        # sizes. TODO confirm against the *_stft_params values.
        tf_input_spec = tf.concat(
            [tf_mag_spec_mix, tf_mr1_mag_spec_mix, tf_mr2_mag_spec_mix], 3)
        tf_input_spec = tf_input_spec[:, :, :1024, :]
        u_net = UNet(input_shape=tf_input_spec.shape[1:])
        tf_est_masks = u_net(tf_input_spec)

        # NOTE(review): the network sees entries 0..1023 on axis 2, but the
        # zero column is placed *before* its output. Assuming output bin i
        # corresponds to input bin i, mask i is applied to bin i + 1. Confirm
        # this one-bin offset is intentional.
        zero_pad = tf.zeros_like(tf_mag_spec_mix)
        zero_pad = tf.expand_dims(zero_pad[:, :, 1, :], -1)
        tf_est_masks = tf.concat([zero_pad, tf_est_masks], 2)
        tf_est_spec = tf.math.multiply(tf_est_masks, tf_amp_spec_mix)

        tf_loss = 10 * Loss.mean_square_error(tf_est_spec, tf_amp_spec_target)
        tf_train_step = Trainer.Adam(tf_loss, tf_lr)

        return (tf_train_step, tf_loss, tf_amp_spec_target, tf_mag_spec_mix,
                tf_spec_mix, tf_est_masks, tf_est_spec)