def __init__(self, mixed_input, voice_input, mixed_phase, mixed_audio, voice_audio, background_audio,
                 variant, is_training, learning_rate,
                 data_type, name):
        """
        Build a voice-mask network together with its loss and Adam training op.

        mixed_input:      Tensor fed to the mask network; the predicted mask is applied to it
        voice_input:      Target tensor the masked output is compared against in the loss
        mixed_phase, mixed_audio, voice_audio, background_audio:
                          Stored on the instance; not used while building the graph here
        variant:          'unet', 'capsunet', 'basic_capsnet' or 'conv_net'
        is_training:      Forwarded to the UNet / conv_net constructors
        learning_rate:    Learning rate handed to the Adam optimiser
        data_type:        'mag', 'mag_phase', 'mag_phase_diff' or 'real_imag'; selects the loss
        name:             Variable scope under which everything is created

        Raises ValueError for an unrecognised variant or data_type (previously
        these surfaced later as an AttributeError on a missing attribute).
        """
        with tf.variable_scope(name):
            self.mixed_input = mixed_input
            self.voice_input = voice_input
            self.mixed_phase = mixed_phase
            self.mixed_audio = mixed_audio
            self.voice_audio = voice_audio
            self.background_audio = background_audio
            self.variant = variant
            self.is_training = is_training

            # Initialise the selected model variant
            if self.variant in ['unet', 'capsunet']:
                self.voice_mask_network = UNet(mixed_input, variant, is_training=is_training, reuse=False, name='voice-mask-unet')
            elif self.variant == 'basic_capsnet':
                self.voice_mask_network = BasicCapsnet(mixed_input, name='SegCaps_CapsNetBasic')
            elif self.variant == 'conv_net':
                self.voice_mask_network = conv_net(mixed_input, is_training=is_training, reuse=None, name='basic_cnn')
            else:
                raise ValueError('Unknown model variant: {!r}'.format(variant))

            self.voice_mask = self.voice_mask_network.output

            # Depending on the data_type, set up the generated voice and the loss
            if data_type == 'mag':
                self.gen_voice = self.voice_mask * mixed_input
                self.cost = mf.l1_loss(self.gen_voice, voice_input)

            elif data_type == 'mag_phase':
                # Channel 0 feeds the magnitude loss, channel 1 the (down-weighted) phase loss
                self.gen_voice = self.voice_mask * mixed_input
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0], voice_input[:, :, :, 0])
                self.phase_loss = mf.l1_phase_loss(self.gen_voice[:, :, :, 1], voice_input[:, :, :, 1]) * 0.00001
                self.cost = (self.mag_loss + self.phase_loss)/2

            elif data_type == 'mag_phase_diff':
                # Channel 0 of the mask multiplies the mixture; channel 1 is added to it
                self.gen_voice_mag = tf.expand_dims(self.voice_mask[:, :, :, 0] * mixed_input[:, :, :, 0], axis=3)
                self.mag_loss = mf.l1_loss(self.gen_voice_mag[:, :, :, 0], voice_input[:, :, :, 0])
                # NOTE(review): l1_phase_loss is nested inside itself here, so the inner
                # result is what gets compared against the channel-1 mask. Confirm this
                # is intended rather than a dedicated phase-difference helper.
                self.phase_loss = mf.l1_phase_loss(mf.l1_phase_loss(mixed_input[:, :, :, 1], voice_input[:, :, :, 1]),
                                                   self.voice_mask[:, :, :, 1]) * 0.00001
                self.cost = (self.mag_loss + self.phase_loss) / 2
                self.gen_voice_phase = tf.expand_dims(self.voice_mask[:, :, :, 1] + mixed_input[:, :, :, 1], axis=3)
                self.gen_voice = tf.concat((self.gen_voice_mag, self.gen_voice_phase), axis=3)

            elif data_type == 'real_imag':
                self.gen_voice = self.voice_mask * mixed_input
                self.real_loss = mf.l1_loss(self.gen_voice[:, :, :, 0], voice_input[:, :, :, 0])
                self.imag_loss = mf.l1_loss(self.gen_voice[:, :, :, 1], voice_input[:, :, :, 1])
                self.cost = (self.real_loss + self.imag_loss)/2

            else:
                raise ValueError('Unknown data_type: {!r}'.format(data_type))

            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.5,
            )
            self.train_op = self.optimizer.minimize(self.cost)
예제 #2
0
    def __init__(self, mixed_mag, voice_mag, mixed_phase, mixed_audio,
                 voice_audio, variant, is_training, learning_rate, name):
        """
        Build a voice-mask network with an L1 loss and Adam training op.

        mixed_mag:        Tensor fed to the mask network; the predicted mask is multiplied with it
        voice_mag:        Target tensor for the L1 loss on the masked output
        mixed_phase, mixed_audio, voice_audio:
                          Stored on the instance; not used while building the graph here
        variant:          'unet', 'capsunet' or 'basic_capsnet'
        is_training:      Forwarded to the UNet constructor
        learning_rate:    Learning rate handed to the Adam optimiser
        name:             Variable scope under which everything is created

        Raises ValueError for an unrecognised variant (previously this surfaced
        as an AttributeError on the missing voice_mask_network attribute).
        """
        with tf.variable_scope(name):
            self.mixed_mag = mixed_mag
            self.voice_mag = voice_mag
            self.mixed_phase = mixed_phase
            self.mixed_audio = mixed_audio
            self.voice_audio = voice_audio
            self.variant = variant
            self.is_training = is_training

            # Initialise the selected model variant
            if self.variant in ['unet', 'capsunet']:
                self.voice_mask_network = UNet(mixed_mag,
                                               variant,
                                               is_training=is_training,
                                               reuse=False,
                                               name='voice-mask-unet')
            elif self.variant == 'basic_capsnet':
                self.voice_mask_network = BasicCapsnet(
                    mixed_mag, name='SegCaps_CapsNetBasic')
            else:
                raise ValueError(
                    'Unknown model variant: {!r}'.format(variant))

            self.voice_mask = self.voice_mask_network.output

            # The generated voice is the mixture scaled element-wise by the mask
            self.gen_voice = self.voice_mask * mixed_mag

            self.cost = mf.l1_loss(self.gen_voice, voice_mag)

            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.5,
            )
            self.train_op = self.optimizer.minimize(self.cost)
예제 #3
0
    def __init__(self, mixed_spec, voice_spec, mixed_audio, voice_audio, variant, is_training, learning_rate,
                 name='complex_unet_model'):
        """
        Wire a ComplexUNet mask estimator to an L1 loss and an Adam training op.

        mixed_spec:       Tensor fed to the ComplexUNet; the predicted mask is multiplied with it
        voice_spec:       Target tensor for the L1 loss on the masked output
        mixed_audio, voice_audio:
                          Stored on the instance; not used while building the graph here
        variant:          Forwarded to the ComplexUNet constructor
        is_training:      Forwarded to the ComplexUNet constructor
        learning_rate:    Learning rate handed to the Adam optimiser
        name:             Variable scope under which everything is created
        """
        with tf.variable_scope(name):
            # Keep references to the inputs and configuration on the instance
            self.mixed_spec, self.voice_spec = mixed_spec, voice_spec
            self.input_shape = mixed_spec.get_shape().as_list()
            self.mixed_audio, self.voice_audio = mixed_audio, voice_audio
            self.variant, self.is_training = variant, is_training

            # Mask estimator
            mask_net = ComplexUNet(mixed_spec, variant, is_training=is_training, reuse=False,
                                   name='voice-mask-unet')
            self.voice_mask_unet = mask_net
            mask = mask_net.output
            self.voice_mask = mask

            # Apply the mask to the mixture and score the estimate against the target
            estimate = mask * mixed_spec
            self.gen_voice = estimate
            self.cost = mf.l1_loss(estimate, voice_spec)

            # Optimisation
            adam = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5)
            self.optimizer = adam
            self.train_op = adam.minimize(self.cost)
    def __init__(self,
                 mixed_spec,
                 voice_spec,
                 is_training,
                 reuse=True,
                 name='complex_number_capsnet',
                 learning_rate=0.0002):
        """
        Build a small convolutional-capsule network that reconstructs voice_spec
        from mixed_spec, with an L1 loss and an Adam training op.

        mixed_spec:    Tensor with shape [batch_size, height, width, 2], where the two channels are the real
                       and imaginary parts of the spectrogram
        voice_spec:    Target tensor for the L1 loss on the reconstruction
        is_training:   Accepted for interface compatibility; not used by this constructor
        reuse:         Accepted for interface compatibility; not used by this constructor
        name:          Model instance name (used as the variable scope)
        learning_rate: Learning rate handed to the Adam optimiser (defaults to the
                       previously hard-coded value of 0.0002)
        """
        with tf.variable_scope(name):
            self.mixed_spec = mixed_spec
            self.voice_spec = voice_spec

            with tf.variable_scope('Primary_Caps'):
                # Reshape layer to be 1 capsule x [filters] atoms
                _, H, W, C = mixed_spec.get_shape()
                input_caps = layers.Reshape(
                    (H.value, W.value, 1, C.value))(mixed_spec)
                self.input_caps = input_caps

            with tf.variable_scope('Conv_Caps'):
                conv_caps = capsule_layers.ConvCapsuleLayer(
                    kernel_size=5,
                    num_capsule=8,
                    num_atoms=2,
                    strides=1,
                    padding='same',
                    routings=1,
                    name='primarycaps')(input_caps)
                self.conv_caps = conv_caps

            # NOTE: an intermediate 'Seg_Caps' capsule layer (kernel_size=1,
            # num_capsule=16, num_atoms=2, routings=3) used to sit here and is
            # currently disabled; conv_caps feeds the reconstruction directly.

            with tf.variable_scope('Reconstruction'):
                reconstruction = capsule_layers.ConvCapsuleLayer(
                    kernel_size=1,
                    num_capsule=1,
                    num_atoms=2,
                    strides=1,
                    padding='same',
                    routings=3,
                    name='seg_caps')(conv_caps)
                # Reshape the capsule output back to the input's [H, W, C] layout
                reconstruction = layers.Reshape(
                    (H.value, W.value, C.value))(reconstruction)
                self.reconstruction = reconstruction

            self.cost = mf.l1_loss(self.reconstruction, voice_spec)

            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.5,
            )
            self.train_op = self.optimizer.minimize(self.cost)
 def __init__(self, mixed_input, voice_input, mixed_phase, mixed_audio,
              voice_audio, background_audio, data_type, name):
     """
     Compute a loss directly between mixed_input and voice_input (no network).

     mixed_input:  Tensor compared against voice_input
     voice_input:  Target tensor
     mixed_phase, mixed_audio, voice_audio, background_audio:
                   Stored on the instance; not used in the loss
     data_type:    'mag', 'mag_phase' or 'real_imag'; any other value leaves
                   the instance without a cost attribute
     name:         Variable scope under which the loss ops are created
     """
     with tf.variable_scope(name):
         # Keep references to every input on the instance
         self.mixed_input, self.voice_input = mixed_input, voice_input
         self.mixed_phase = mixed_phase
         self.mixed_audio, self.voice_audio = mixed_audio, voice_audio
         self.background_audio = background_audio
         self.data_type = data_type

         if data_type == 'mag':
             # Single-channel case: one L1 loss over the whole tensor
             self.cost = mf.l1_loss(mixed_input, voice_input)
         elif data_type in ('mag_phase', 'real_imag'):
             # Two-channel cases: channel 0 always uses the plain L1 loss
             first_loss = mf.l1_loss(mixed_input[:, :, :, 0],
                                     voice_input[:, :, :, 0])
             if data_type == 'mag_phase':
                 # Channel 1 uses the phase loss, scaled down
                 second_loss = mf.l1_phase_loss(
                     mixed_input[:, :, :, 1],
                     voice_input[:, :, :, 1]) * 0.00001
                 self.mag_loss, self.phase_loss = first_loss, second_loss
             else:
                 second_loss = mf.l1_loss(mixed_input[:, :, :, 1],
                                          voice_input[:, :, :, 1])
                 self.real_loss, self.imag_loss = first_loss, second_loss
             # The cost is the mean of the two per-channel losses
             self.cost = (first_loss + second_loss) / 2
예제 #6
0
    def __init__(self, mixed_input, voice_input, mixed_phase, mixed_audio,
                 voice_audio, background_audio, is_training, learning_rate,
                 data_type, phase_weight, phase_loss_masking,
                 phase_loss_approximation, name):
        """
        Build a UNet voice-mask model with a data_type-dependent loss and an
        Adam training op.

        mixed_input:      Tensor fed to the UNet (only channels 0:2 for
                          'complex_to_mag_phase'); the mask is applied to it
        voice_input:      Target tensor used by the loss terms
        mixed_phase:      Stored on the instance; used by the 'mag_phase2' branch
        mixed_audio, voice_audio, background_audio:
                          Stored on the instance; not used while building the graph here
        is_training:      Forwarded to the UNet constructor
        learning_rate:    Learning rate handed to the Adam optimiser
        data_type:        'mag', 'mag_phase', 'mag_phase_diff2', 'mag_phase_diff',
                          'real_imag', 'mag_real_imag', 'mag_phase2',
                          'mag_phase_real_imag' or 'complex_to_mag_phase'
        phase_weight:     Multiplier applied to every phase loss term
        phase_loss_masking, phase_loss_approximation:
                          Forwarded unchanged to mf.l1_phase_loss
        name:             Variable scope under which everything is created

        Raises ValueError for an unrecognised data_type (previously this
        surfaced as an AttributeError on the missing cost attribute).
        """
        with tf.variable_scope(name):
            self.mixed_input = mixed_input
            self.voice_input = voice_input
            self.mixed_phase = mixed_phase
            self.mixed_audio = mixed_audio
            self.voice_audio = voice_audio
            self.background_audio = background_audio
            self.is_training = is_training

            # Initialise the mask network; 'complex_to_mag_phase' feeds it only
            # the first two channels of the mixture
            if data_type == 'complex_to_mag_phase':
                network_input = mixed_input[:, :, :, 0:2]
            else:
                network_input = mixed_input
            self.voice_mask_network = UNet(network_input,
                                           data_type,
                                           is_training=is_training,
                                           reuse=False,
                                           name='voice-mask-unet')

            self.voice_mask = self.voice_mask_network.output

            # Depending on the data_type, setup the loss functions and optimisation
            if data_type == 'mag':
                self.gen_voice = self.voice_mask * mixed_input
                self.cost = mf.l1_loss(self.gen_voice, voice_input)

            elif data_type == 'mag_phase':
                self.gen_voice = self.voice_mask * mixed_input
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.phase_loss = mf.l1_phase_loss(
                    self.gen_voice[:, :, :, 1], voice_input[:, :, :, 1],
                    phase_loss_masking, phase_loss_approximation,
                    self.gen_voice[:, :, :, 0]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type == 'mag_phase_diff2':
                # The channel-1 mask is trained to match the phase difference
                # between mixture and voice, then added to the mixture phase
                self.gen_voice_mag = tf.expand_dims(
                    self.voice_mask[:, :, :, 0] * mixed_input[:, :, :, 0],
                    axis=3)
                self.mag_loss = mf.l1_loss(self.gen_voice_mag[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.phase_loss = mf.l1_phase_loss(
                    mf.phase_difference(
                        mixed_input[:, :, :, 1],
                        voice_input[:, :, :, 1]), self.voice_mask[:, :, :, 1],
                    phase_loss_masking, phase_loss_approximation,
                    self.gen_voice_mag) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2
                self.gen_voice_phase = tf.expand_dims(
                    self.voice_mask[:, :, :, 1] + mixed_input[:, :, :, 1],
                    axis=3)
                self.gen_voice = mf.concat(self.gen_voice_mag,
                                           self.gen_voice_phase)

            elif data_type == 'mag_phase_diff':
                self.gen_voice_mag = tf.expand_dims(
                    self.voice_mask[:, :, :, 0] * mixed_input[:, :, :, 0],
                    axis=3)
                self.gen_voice_phase = tf.expand_dims(
                    self.voice_mask[:, :, :, 1] + mixed_input[:, :, :, 1],
                    axis=3)
                self.gen_voice = mf.concat(self.gen_voice_mag,
                                           self.gen_voice_phase)
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                # NOTE(review): gen_voice_phase keeps the trailing axis added by
                # expand_dims while the target slice drops it; confirm that
                # mf.l1_phase_loss copes with the differing ranks.
                self.phase_loss = mf.l1_phase_loss(
                    self.gen_voice_phase, voice_input[:, :, :, 1],
                    phase_loss_masking, phase_loss_approximation,
                    self.gen_voice_mag) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type == 'real_imag':
                self.gen_voice = self.voice_mask * mixed_input
                self.real_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                            voice_input[:, :, :, 0])
                self.imag_loss = mf.l1_loss(self.gen_voice[:, :, :, 1],
                                            voice_input[:, :, :, 1])
                self.cost = (self.real_loss + self.imag_loss) / 2

            elif data_type == 'mag_real_imag':
                self.gen_voice = self.voice_mask * mixed_input
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.real_loss = mf.l1_loss(self.gen_voice[:, :, :, 1],
                                            voice_input[:, :, :, 1])
                self.imag_loss = mf.l1_loss(self.gen_voice[:, :, :, 2],
                                            voice_input[:, :, :, 2])
                self.cost = (self.mag_loss + self.real_loss +
                             self.imag_loss) / 3

            elif data_type == 'mag_phase2':
                # Mask channels 1 and 2 are combined into one angle via
                # tf.complex / tf.angle; voice_mask is replaced by the
                # two-channel [mag_mask, phase_mask] tensor
                self.mag_mask = self.voice_mask[:, :, :, 0]
                self.phase_mask = tf.angle(
                    tf.complex(self.voice_mask[:, :, :, 1],
                               self.voice_mask[:, :, :, 2]))
                self.voice_mask = mf.concat(
                    tf.expand_dims(self.mag_mask, axis=3),
                    tf.expand_dims(self.phase_mask, axis=3))
                self.gen_voice_mag = self.mag_mask * mixed_input[:, :, :, 0]
                self.gen_voice_phase = self.phase_mask * tf.squeeze(
                    mixed_phase, axis=3)
                # Target angle derived the same way from voice channels 1 and 2
                self.voice_phase = tf.angle(
                    tf.complex(self.voice_input[:, :, :, 1],
                               self.voice_input[:, :, :, 2]))
                self.gen_voice = mf.concat(
                    tf.expand_dims(self.gen_voice_mag, axis=3),
                    tf.expand_dims(self.gen_voice_phase, axis=3))
                self.mag_loss = mf.l1_loss(self.gen_voice_mag,
                                           voice_input[:, :, :, 0])
                self.phase_loss = mf.l1_phase_loss(
                    self.gen_voice_phase, self.voice_phase, phase_loss_masking,
                    phase_loss_approximation,
                    self.gen_voice_mag) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type in ('mag_phase_real_imag', 'complex_to_mag_phase'):
                # Both types apply the mask to mixture channels 2:4 and compare
                # against voice channels 2 and 3; they differ only in what the
                # network was fed above
                self.gen_voice = self.voice_mask * mixed_input[:, :, :, 2:4]
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 2])
                self.phase_loss = mf.l1_phase_loss(
                    self.gen_voice[:, :, :, 1], voice_input[:, :, :, 3],
                    phase_loss_masking, phase_loss_approximation,
                    self.gen_voice[:, :, :, 0]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            else:
                raise ValueError(
                    'Unknown data_type: {!r}'.format(data_type))

            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.5,
            )
            self.train_op = self.optimizer.minimize(self.cost)
    def __init__(self, mixed_input, voice_input, mixed_phase, mixed_audio,
                 voice_audio, background_audio, variant, is_training,
                 learning_rate, data_type, phase_weight, name):
        """
        Build the voice-mask network selected by variant, with a
        data_type-dependent loss and an Adam training op.

        mixed_input:      Tensor fed to the mask network; the mask is applied to it
        voice_input:      Target tensor used by the loss terms
        mixed_phase:      Stored on the instance; used by the 'mag_phase2' branch
        mixed_audio, voice_audio, background_audio:
                          Stored on the instance; not used while building the graph here
        variant:          'unet', 'capsunet', 'basic_capsnet' or 'basic_convnet'
        is_training:      Forwarded to the UNet / BasicConvNet constructors
        learning_rate:    Learning rate handed to the Adam optimiser
        data_type:        'mag', 'mag_phase', 'mag_phase_diff', 'real_imag',
                          'mag_real_imag', 'mag_phase2' or 'mag_phase_real_imag'
        phase_weight:     Multiplier applied to every phase loss term
        name:             Variable scope under which everything is created

        Raises ValueError for an unrecognised variant or data_type (previously
        these surfaced later as an AttributeError on a missing attribute).
        """
        with tf.variable_scope(name):
            self.mixed_input = mixed_input
            self.voice_input = voice_input
            self.mixed_phase = mixed_phase
            self.mixed_audio = mixed_audio
            self.voice_audio = voice_audio
            self.background_audio = background_audio
            self.variant = variant
            self.is_training = is_training

            # Initialise the selected model variant
            if self.variant in ['unet', 'capsunet']:
                self.voice_mask_network = UNet(mixed_input,
                                               variant,
                                               data_type,
                                               is_training=is_training,
                                               reuse=False,
                                               name='voice-mask-unet')
            elif self.variant == 'basic_capsnet':
                self.voice_mask_network = BasicCapsNet(mixed_input,
                                                       name='basic_capsnet')
            elif self.variant == 'basic_convnet':
                self.voice_mask_network = BasicConvNet(mixed_input,
                                                       is_training=is_training,
                                                       reuse=None,
                                                       name='basic_convnet')
            else:
                raise ValueError(
                    'Unknown model variant: {!r}'.format(variant))

            self.voice_mask = self.voice_mask_network.output

            # Depending on the data_type, setup the loss functions and optimisation
            if data_type == 'mag':
                self.gen_voice = self.voice_mask * mixed_input
                self.cost = mf.l1_loss(self.gen_voice, voice_input)

            elif data_type == 'mag_phase':
                self.gen_voice = self.voice_mask * mixed_input
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.phase_loss = mf.l1_phase_loss(
                    self.gen_voice[:, :, :, 1],
                    voice_input[:, :, :, 1]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type == 'mag_phase_diff':
                # The channel-1 mask is trained to match the phase difference
                # between mixture and voice, then added to the mixture phase
                self.gen_voice_mag = tf.expand_dims(
                    self.voice_mask[:, :, :, 0] * mixed_input[:, :, :, 0],
                    axis=3)
                self.mag_loss = mf.l1_loss(self.gen_voice_mag[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                # Weighted by phase_weight like every other phase loss in this
                # constructor (this branch used to hard-code 0.00001 and
                # silently ignore the parameter)
                self.phase_loss = mf.l1_phase_loss(
                    mf.phase_difference(mixed_input[:, :, :, 1],
                                        voice_input[:, :, :, 1]),
                    self.voice_mask[:, :, :, 1]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2
                self.gen_voice_phase = tf.expand_dims(
                    self.voice_mask[:, :, :, 1] + mixed_input[:, :, :, 1],
                    axis=3)
                self.gen_voice = mf.concat(self.gen_voice_mag,
                                           self.gen_voice_phase)

            elif data_type == 'real_imag':
                self.gen_voice = self.voice_mask * mixed_input
                self.real_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                            voice_input[:, :, :, 0])
                self.imag_loss = mf.l1_loss(self.gen_voice[:, :, :, 1],
                                            voice_input[:, :, :, 1])
                self.cost = (self.real_loss + self.imag_loss) / 2

            elif data_type == 'mag_real_imag':
                self.gen_voice = self.voice_mask * mixed_input
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.real_loss = mf.l1_loss(self.gen_voice[:, :, :, 1],
                                            voice_input[:, :, :, 1])
                self.imag_loss = mf.l1_loss(self.gen_voice[:, :, :, 2],
                                            voice_input[:, :, :, 2])
                self.cost = (self.mag_loss + self.real_loss +
                             self.imag_loss) / 3

            elif data_type == 'mag_phase2':
                # Mask channels 1 and 2 are combined into one angle via
                # tf.complex / tf.angle; voice_mask is replaced by the
                # two-channel [mag_mask, phase_mask] tensor
                self.mag_mask = self.voice_mask[:, :, :, 0]
                self.phase_mask = tf.angle(
                    tf.complex(self.voice_mask[:, :, :, 1],
                               self.voice_mask[:, :, :, 2]))
                self.voice_mask = mf.concat(
                    tf.expand_dims(self.mag_mask, axis=3),
                    tf.expand_dims(self.phase_mask, axis=3))
                self.gen_mag = self.mag_mask * mixed_input[:, :, :, 0]
                self.gen_phase = self.phase_mask * tf.squeeze(mixed_phase,
                                                              axis=3)
                # Target angle derived the same way from voice channels 1 and 2
                self.voice_phase = tf.angle(
                    tf.complex(self.voice_input[:, :, :, 1],
                               self.voice_input[:, :, :, 2]))
                self.gen_voice = mf.concat(
                    tf.expand_dims(self.gen_mag, axis=3),
                    tf.expand_dims(self.gen_phase, axis=3))
                self.mag_loss = mf.l1_loss(self.gen_mag,
                                           voice_input[:, :, :, 0])
                self.phase_loss = mf.l1_phase_loss(
                    self.gen_phase, self.voice_phase) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type == 'mag_phase_real_imag':
                # The mask is applied to mixture channels 2:4 and compared
                # against voice channels 2 and 3
                self.gen_voice = self.voice_mask * mixed_input[:, :, :, 2:4]
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 2])
                self.phase_loss = mf.l1_phase_loss(
                    self.gen_voice[:, :, :, 1],
                    voice_input[:, :, :, 3]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            else:
                raise ValueError(
                    'Unknown data_type: {!r}'.format(data_type))

            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.5,
            )
            self.train_op = self.optimizer.minimize(self.cost)
예제 #8
0
    def __init__(self, mixed_input, voice_input, mixed_phase, mixed_audio,
                 voice_audio, background_audio, variant, is_training,
                 learning_rate, data_type, phase_weight, phase_loss_function,
                 name):
        """
        Build the voice-mask network selected by variant, with a
        data_type-dependent loss, a selectable phase loss function and an
        Adam training op.

        mixed_input:         Tensor fed to the mask network (only channels 0:2 for the
                             UNet variants with 'complex_to_mag_phase'); the mask is applied to it
        voice_input:         Target tensor used by the loss terms
        mixed_phase:         Stored on the instance; used by the 'mag_phase2' branch
        mixed_audio, voice_audio, background_audio:
                             Stored on the instance; not used while building the graph here
        variant:             'unet', 'capsunet', 'noconvcapsunet', 'basic_capsnet'
                             or 'basic_convnet'
        is_training:         Forwarded to the UNet / BasicConvNet constructors
        learning_rate:       Learning rate handed to the Adam optimiser
        data_type:           'mag', 'mag_phase', 'mag_phase_diff2', 'mag_phase_diff',
                             'real_imag', 'mag_real_imag', 'mag_phase2',
                             'mag_phase_real_imag' or 'complex_to_mag_phase'
        phase_weight:        Multiplier applied to every phase loss term
        phase_loss_function: 'l1', 'l2', 'l1_circular' or 'l2_circular'. The legacy
                             misspelling 'l1_crcular' is still accepted.
        name:                Variable scope under which everything is created

        Raises ValueError for an unrecognised variant or data_type, and for an
        unrecognised phase_loss_function when the data_type needs a phase loss
        (previously these surfaced later as an AttributeError).
        """
        with tf.variable_scope(name):
            self.mixed_input = mixed_input
            self.voice_input = voice_input
            self.mixed_phase = mixed_phase
            self.mixed_audio = mixed_audio
            self.voice_audio = voice_audio
            self.background_audio = background_audio
            self.variant = variant
            self.is_training = is_training

            # Data types whose loss includes a phase term
            needs_phase_loss = data_type in (
                'mag_phase', 'mag_phase_diff2', 'mag_phase_diff',
                'mag_phase2', 'mag_phase_real_imag', 'complex_to_mag_phase')

            # Set the loss function
            if phase_loss_function == 'l1':
                self.phase_loss_function = mf.l1_loss
            elif phase_loss_function == 'l2':
                self.phase_loss_function = mf.l2_loss
            elif phase_loss_function in ('l1_circular', 'l1_crcular'):
                # 'l1_crcular' was the only spelling recognised before the typo
                # was fixed; keep it so existing configurations still work
                self.phase_loss_function = mf.l1_phase_loss
            elif phase_loss_function == 'l2_circular':
                self.phase_loss_function = mf.l2_phase_loss
            elif needs_phase_loss:
                # Only fail when the function would actually be called; data
                # types without a phase term never touched it before either
                raise ValueError('Unknown phase_loss_function: {!r}'.format(
                    phase_loss_function))

            # Initialise the selected model variant
            if self.variant in ['unet', 'capsunet', 'noconvcapsunet']:
                # 'complex_to_mag_phase' feeds the network only the first two
                # channels of the mixture
                if data_type == 'complex_to_mag_phase':
                    network_input = mixed_input[:, :, :, 0:2]
                else:
                    network_input = mixed_input
                self.voice_mask_network = UNet(network_input,
                                               variant,
                                               data_type,
                                               is_training=is_training,
                                               reuse=False,
                                               name='voice-mask-unet')
            elif self.variant == 'basic_capsnet':
                self.voice_mask_network = BasicCapsNet(mixed_input,
                                                       name='basic_capsnet')
            elif self.variant == 'basic_convnet':
                self.voice_mask_network = BasicConvNet(mixed_input,
                                                       is_training=is_training,
                                                       reuse=None,
                                                       name='basic_convnet')
            else:
                raise ValueError(
                    'Unknown model variant: {!r}'.format(variant))

            self.voice_mask = self.voice_mask_network.output

            # Depending on the data_type, setup the loss functions and optimisation
            if data_type == 'mag':
                self.gen_voice = self.voice_mask * mixed_input
                self.cost = mf.l1_loss(self.gen_voice, voice_input)

            elif data_type == 'mag_phase':
                self.gen_voice = self.voice_mask * mixed_input
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.phase_loss = self.phase_loss_function(
                    self.gen_voice[:, :, :, 1],
                    voice_input[:, :, :, 1]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type == 'mag_phase_diff2':
                # The channel-1 mask is trained to match the phase difference
                # between mixture and voice, then added to the mixture phase
                self.gen_voice_mag = tf.expand_dims(
                    self.voice_mask[:, :, :, 0] * mixed_input[:, :, :, 0],
                    axis=3)
                self.mag_loss = mf.l1_loss(self.gen_voice_mag[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.phase_loss = self.phase_loss_function(
                    mf.phase_difference(mixed_input[:, :, :, 1],
                                        voice_input[:, :, :, 1]),
                    self.voice_mask[:, :, :, 1]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2
                self.gen_voice_phase = tf.expand_dims(
                    self.voice_mask[:, :, :, 1] + mixed_input[:, :, :, 1],
                    axis=3)
                self.gen_voice = mf.concat(self.gen_voice_mag,
                                           self.gen_voice_phase)

            elif data_type == 'mag_phase_diff':
                # Same generated voice as 'mag_phase_diff2', but the phase loss
                # compares the generated phase directly with the target phase
                self.gen_voice_mag = tf.expand_dims(
                    self.voice_mask[:, :, :, 0] * mixed_input[:, :, :, 0],
                    axis=3)
                self.gen_voice_phase = tf.expand_dims(
                    self.voice_mask[:, :, :, 1] + mixed_input[:, :, :, 1],
                    axis=3)
                self.gen_voice = mf.concat(self.gen_voice_mag,
                                           self.gen_voice_phase)
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.phase_loss = self.phase_loss_function(
                    self.gen_voice[:, :, :, 1],
                    voice_input[:, :, :, 1]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type == 'real_imag':
                self.gen_voice = self.voice_mask * mixed_input
                self.real_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                            voice_input[:, :, :, 0])
                self.imag_loss = mf.l1_loss(self.gen_voice[:, :, :, 1],
                                            voice_input[:, :, :, 1])
                self.cost = (self.real_loss + self.imag_loss) / 2

            elif data_type == 'mag_real_imag':
                self.gen_voice = self.voice_mask * mixed_input
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 0])
                self.real_loss = mf.l1_loss(self.gen_voice[:, :, :, 1],
                                            voice_input[:, :, :, 1])
                self.imag_loss = mf.l1_loss(self.gen_voice[:, :, :, 2],
                                            voice_input[:, :, :, 2])
                self.cost = (self.mag_loss + self.real_loss +
                             self.imag_loss) / 3

            elif data_type == 'mag_phase2':
                # Mask channels 1 and 2 are combined into one angle via
                # tf.complex / tf.angle; voice_mask is replaced by the
                # two-channel [mag_mask, phase_mask] tensor
                self.mag_mask = self.voice_mask[:, :, :, 0]
                self.phase_mask = tf.angle(
                    tf.complex(self.voice_mask[:, :, :, 1],
                               self.voice_mask[:, :, :, 2]))
                self.voice_mask = mf.concat(
                    tf.expand_dims(self.mag_mask, axis=3),
                    tf.expand_dims(self.phase_mask, axis=3))
                self.gen_mag = self.mag_mask * mixed_input[:, :, :, 0]
                self.gen_phase = self.phase_mask * tf.squeeze(mixed_phase,
                                                              axis=3)
                # Target angle derived the same way from voice channels 1 and 2
                self.voice_phase = tf.angle(
                    tf.complex(self.voice_input[:, :, :, 1],
                               self.voice_input[:, :, :, 2]))
                self.gen_voice = mf.concat(
                    tf.expand_dims(self.gen_mag, axis=3),
                    tf.expand_dims(self.gen_phase, axis=3))
                self.mag_loss = mf.l1_loss(self.gen_mag,
                                           voice_input[:, :, :, 0])
                self.phase_loss = self.phase_loss_function(
                    self.gen_phase, self.voice_phase) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            elif data_type in ('mag_phase_real_imag', 'complex_to_mag_phase'):
                # Both types apply the mask to mixture channels 2:4 and compare
                # against voice channels 2 and 3; they differ only in what the
                # network was fed above
                self.gen_voice = self.voice_mask * mixed_input[:, :, :, 2:4]
                self.mag_loss = mf.l1_loss(self.gen_voice[:, :, :, 0],
                                           voice_input[:, :, :, 2])
                self.phase_loss = self.phase_loss_function(
                    self.gen_voice[:, :, :, 1],
                    voice_input[:, :, :, 3]) * phase_weight
                self.cost = (self.mag_loss + self.phase_loss) / 2

            else:
                raise ValueError(
                    'Unknown data_type: {!r}'.format(data_type))

            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.5,
            )
            self.train_op = self.optimizer.minimize(self.cost)