def __init__(self, x_mag_spec_batch, lengths_batch, y_mag_spec_batch=None,
             theta_x_batch=None, theta_y_batch=None, behavior='train'):
    '''Build the bidirectional-RNN masking model graph.

    behavior = 'train/validation/infer'

    The normalized noisy spectrum (optionally batch-normalized) is encoded by
    a stacked BLSTM or BGRU, a fully connected layer predicts a mask that is
    applied to the normalized noisy spectrum, and an optional CBHG post-net
    refines the estimate. The saver is always built; the loss is built for
    train/validation and the optimizer only for train.

    Args:
      x_mag_spec_batch: noisy magnitude spectrum batch. Presumably
        [batch, time, freq] -- the mask is reshaped to
        [batch, -1, OUTPUT_SIZE] before being multiplied with it (TODO
        confirm).
      lengths_batch: per-utterance lengths; its first dimension is used as
        the batch size and it is the RNN sequence_length.
      y_mag_spec_batch: clean magnitude spectrum batch; required unless
        behavior is infer.
      theta_x_batch: presumably the noisy phase; only used through
        cos(theta_x - theta_y) by the phase-sensitive label types. Required
        unless behavior is infer.
      theta_y_batch: presumably the clean phase; see theta_x_batch.
      behavior: compared against self.train / self.validation / self.infer.

    Raises:
      ValueError: FLAGS.PARAM.MODEL_TYPE is neither BLSTM nor BGRU
        (case-insensitive).
      SystemExit: via exit(-1) on an unsupported MASK_TYPE or loss type, or
        when the CBHG post-net is used with DECODING_MASK_POSITION !=
        TRAINING_MASK_POSITION.
    '''
    if behavior != self.infer:
        assert (y_mag_spec_batch is not None)
        assert (theta_x_batch is not None)
        assert (theta_y_batch is not None)
    # Offset used by the log-magnitude (de)normalization helpers.
    self._log_bias = tf.get_variable(
        'logbias', [1], trainable=FLAGS.PARAM.LOG_BIAS_TRAINABLE,
        initializer=tf.constant_initializer(FLAGS.PARAM.INIT_LOG_BIAS))
    self._real_logbias = self._log_bias + FLAGS.PARAM.MIN_LOG_BIAS
    self._x_mag_spec = x_mag_spec_batch
    self._norm_x_mag_spec = norm_mag_spec(self._x_mag_spec,
                                          FLAGS.PARAM.MAG_NORM_MAX)
    self._norm_x_logmag_spec = norm_logmag_spec(self._x_mag_spec,
                                                FLAGS.PARAM.MAG_NORM_MAX,
                                                self._log_bias,
                                                FLAGS.PARAM.MIN_LOG_BIAS)

    self._y_mag_spec = y_mag_spec_batch
    self._norm_y_mag_spec = norm_mag_spec(self._y_mag_spec,
                                          FLAGS.PARAM.MAG_NORM_MAX)
    self._norm_y_logmag_spec = norm_logmag_spec(self._y_mag_spec,
                                                FLAGS.PARAM.MAG_NORM_MAX,
                                                self._log_bias,
                                                FLAGS.PARAM.MIN_LOG_BIAS)

    self._lengths = lengths_batch
    self._batch_size = tf.shape(self._lengths)[0]

    self._x_theta = theta_x_batch
    self._y_theta = theta_y_batch
    self._model_type = FLAGS.PARAM.MODEL_TYPE
    # Upper-cased once so every model-type comparison below is
    # case-insensitive (the RNN selection always used .upper()).
    model_type = self._model_type.upper()

    if FLAGS.PARAM.INPUT_TYPE == 'mag':
        self.net_input = self._norm_x_mag_spec
    elif FLAGS.PARAM.INPUT_TYPE == 'logmag':
        self.net_input = self._norm_x_logmag_spec
    if FLAGS.PARAM.LABEL_TYPE == 'mag':
        self._y_labels = self._norm_y_mag_spec
    elif FLAGS.PARAM.LABEL_TYPE == 'logmag':
        self._y_labels = self._norm_y_logmag_spec

    outputs = self.net_input
    if FLAGS.PARAM.INPUT_BN:
        with tf.variable_scope('Batch_Norm_Layer'):
            if_BRN = (FLAGS.PARAM.MVN_TYPE == 'BRN')
            if FLAGS.PARAM.SELF_BN:
                # training=True: always normalize with batch statistics.
                outputs = tf.layers.batch_normalization(outputs, training=True,
                                                        renorm=if_BRN)
            else:
                outputs = tf.layers.batch_normalization(
                    outputs,
                    training=(behavior == self.train or
                              behavior == self.validation),
                    renorm=if_BRN)

    # Outside inference, wrap every RNN cell with output dropout.
    lstm_attn_cell = lstm_cell
    if behavior != self.infer and FLAGS.PARAM.KEEP_PROB < 1.0:
        def lstm_attn_cell(n_units, n_proj, act):
            return tf.contrib.rnn.DropoutWrapper(
                lstm_cell(n_units, n_proj, act),
                output_keep_prob=FLAGS.PARAM.KEEP_PROB)

    GRU_attn_cell = GRU_cell
    if behavior != self.infer and FLAGS.PARAM.KEEP_PROB < 1.0:
        def GRU_attn_cell(n_units, act):
            return tf.contrib.rnn.DropoutWrapper(
                GRU_cell(n_units, act),
                output_keep_prob=FLAGS.PARAM.KEEP_PROB)

    if model_type == 'BLSTM':
        with tf.variable_scope('BLSTM'):
            lstm_fw_cell = tf.contrib.rnn.MultiRNNCell([
                lstm_attn_cell(FLAGS.PARAM.RNN_SIZE,
                               FLAGS.PARAM.LSTM_num_proj,
                               FLAGS.PARAM.LSTM_ACTIVATION)
                for _ in range(FLAGS.PARAM.RNN_LAYER)
            ], state_is_tuple=True)
            lstm_bw_cell = tf.contrib.rnn.MultiRNNCell([
                lstm_attn_cell(FLAGS.PARAM.RNN_SIZE,
                               FLAGS.PARAM.LSTM_num_proj,
                               FLAGS.PARAM.LSTM_ACTIVATION)
                for _ in range(FLAGS.PARAM.RNN_LAYER)
            ], state_is_tuple=True)
            fw_cell = lstm_fw_cell._cells
            bw_cell = lstm_bw_cell._cells
            result = rnn.stack_bidirectional_dynamic_rnn(
                cells_fw=fw_cell,
                cells_bw=bw_cell,
                inputs=outputs,
                dtype=tf.float32,
                sequence_length=self._lengths)
            outputs, fw_final_states, bw_final_states = result
    elif model_type == 'BGRU':
        with tf.variable_scope('BGRU'):
            gru_fw_cell = tf.contrib.rnn.MultiRNNCell([
                GRU_attn_cell(FLAGS.PARAM.RNN_SIZE,
                              FLAGS.PARAM.LSTM_ACTIVATION)
                for _ in range(FLAGS.PARAM.RNN_LAYER)
            ], state_is_tuple=True)
            gru_bw_cell = tf.contrib.rnn.MultiRNNCell([
                GRU_attn_cell(FLAGS.PARAM.RNN_SIZE,
                              FLAGS.PARAM.LSTM_ACTIVATION)
                for _ in range(FLAGS.PARAM.RNN_LAYER)
            ], state_is_tuple=True)
            fw_cell = gru_fw_cell._cells
            bw_cell = gru_bw_cell._cells
            result = rnn.stack_bidirectional_dynamic_rnn(
                cells_fw=fw_cell,
                cells_bw=bw_cell,
                inputs=outputs,
                dtype=tf.float32,
                sequence_length=self._lengths)
            outputs, fw_final_states, bw_final_states = result
    else:
        # Fail fast: otherwise fw_final_states below is unbound and the mask
        # layer would be fed the un-encoded input with a mismatched width.
        raise ValueError('Unknown model type %s.' % FLAGS.PARAM.MODEL_TYPE)
    self.fw_final_state = fw_final_states
    self.bw_final_state = bw_final_states

    # region full connection get mask
    # Width of the RNN output: both directions are concatenated; a projected
    # LSTM emits LSTM_num_proj units per direction instead of RNN_SIZE.
    in_size = FLAGS.PARAM.RNN_SIZE * 2
    if model_type == 'BLSTM' and FLAGS.PARAM.LSTM_num_proj is not None:
        in_size = 2 * FLAGS.PARAM.LSTM_num_proj
    outputs = tf.reshape(outputs, [-1, in_size])
    out_size = FLAGS.PARAM.OUTPUT_SIZE
    # NOTE(review): variable names under the scopes below must match existing
    # checkpoints -- confirm the scope nesting before reusing old models.
    with tf.variable_scope('fullconnectOut'):
        weights = tf.get_variable(
            'weights1', [in_size, out_size],
            initializer=tf.random_normal_initializer(stddev=0.01))
        biases = tf.get_variable(
            'biases1', [out_size],
            initializer=tf.constant_initializer(FLAGS.PARAM.INIT_MASK_VAL))
    if FLAGS.PARAM.TIME_NOSOFTMAX_ATTENTION:
        # A per-frame ReLU coefficient scales every bin of that mask frame.
        with tf.variable_scope('fullconnectCoef'):
            weights_coef = tf.get_variable(
                'weights_coef', [in_size, 1],
                initializer=tf.random_normal_initializer(mean=1.0,
                                                         stddev=0.01))
            biases_coef = tf.get_variable(
                'biases_coef', [1], initializer=tf.constant_initializer(0.0))
        raw_mask = tf.reshape(
            tf.matmul(outputs, weights) + biases,
            [self._batch_size, -1, FLAGS.PARAM.OUTPUT_SIZE])  # [batch,time,fre]
        batch_coef_vec = tf.nn.relu(
            tf.reshape(tf.matmul(outputs, weights_coef) + biases_coef,
                       [self._batch_size, -1]))  # [batch, time]
        mask = tf.multiply(
            raw_mask, tf.reshape(batch_coef_vec, [self._batch_size, -1, 1]))
    else:
        if FLAGS.PARAM.POST_BN:
            # No bias before batch norm; a second dense layer follows it.
            linear_out = tf.matmul(outputs, weights)
            with tf.variable_scope('POST_Batch_Norm_Layer'):
                if_BRN = (FLAGS.PARAM.MVN_TYPE == 'BRN')
                if FLAGS.PARAM.SELF_BN:
                    linear_out = tf.layers.batch_normalization(
                        linear_out, training=True, renorm=if_BRN)
                else:
                    linear_out = tf.layers.batch_normalization(
                        linear_out,
                        training=(behavior == self.train or
                                  behavior == self.validation),
                        renorm=if_BRN)
                weights2 = tf.get_variable(
                    'weights1', [out_size, out_size],
                    initializer=tf.random_normal_initializer(stddev=0.01))
                biases2 = tf.get_variable(
                    'biases1', [out_size],
                    initializer=tf.constant_initializer(
                        FLAGS.PARAM.INIT_MASK_VAL))
                linear_out = tf.matmul(linear_out, weights2) + biases2
        else:
            linear_out = tf.matmul(outputs, weights) + biases
        mask = linear_out
        if FLAGS.PARAM.ReLU_MASK:
            mask = tf.nn.relu(linear_out)
    # endregion

    self._mask = tf.reshape(
        mask, [self._batch_size, -1, FLAGS.PARAM.OUTPUT_SIZE])

    if FLAGS.PARAM.TRAINING_MASK_POSITION == 'mag':
        self._y_estimation = self._mask * (self._norm_x_mag_spec +
                                           FLAGS.PARAM.SPEC_EST_BIAS)
    elif FLAGS.PARAM.TRAINING_MASK_POSITION == 'logmag':
        self._y_estimation = self._mask * (self._norm_x_logmag_spec +
                                           FLAGS.PARAM.SPEC_EST_BIAS)

    # region get infer spec
    # _y_mag_estimation is the estimated (de-normalized) magnitude spectrum;
    # _y_estimation is the loss target domain (normalized mag or logmag).
    if FLAGS.PARAM.DECODING_MASK_POSITION == 'mag':
        self._y_mag_estimation = rm_norm_mag_spec(
            self._mask * (self._norm_x_mag_spec + FLAGS.PARAM.SPEC_EST_BIAS),
            FLAGS.PARAM.MAG_NORM_MAX)
    elif FLAGS.PARAM.DECODING_MASK_POSITION == 'logmag':
        self._y_mag_estimation = rm_norm_logmag_spec(
            self._mask * (self._norm_x_logmag_spec +
                          FLAGS.PARAM.SPEC_EST_BIAS),
            FLAGS.PARAM.MAG_NORM_MAX, self._log_bias, FLAGS.PARAM.MIN_LOG_BIAS)
    # endregion

    # region prepare y_estimation
    # Bring the estimate into the same domain as the labels.
    if FLAGS.PARAM.TRAINING_MASK_POSITION != FLAGS.PARAM.LABEL_TYPE:
        if FLAGS.PARAM.LABEL_TYPE == 'mag':
            self._y_estimation = normedLogmag2normedMag(
                self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias,
                FLAGS.PARAM.MIN_LOG_BIAS)
        elif FLAGS.PARAM.LABEL_TYPE == 'logmag':
            self._y_estimation = normedMag2normedLogmag(
                self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias,
                FLAGS.PARAM.MIN_LOG_BIAS)
    # endregion

    # region CBHG
    if FLAGS.PARAM.USE_CBHG_POST_PROCESSING:
        cbhg_kernels = 8  # All kernel sizes from 1 to cbhg_kernels will be used in the convolution bank of CBHG to act as "K-grams"
        cbhg_conv_channels = 128  # Channels of the convolution bank
        cbhg_pool_size = 2  # pooling size of the CBHG
        cbhg_projection = 256  # projection channels of the CBHG (1st projection, 2nd is automatically set to num_mels)
        cbhg_projection_kernel_size = 3  # kernel_size of the CBHG projections
        cbhg_highwaynet_layers = 4  # Number of HighwayNet layers
        cbhg_highway_units = 128  # Number of units used in HighwayNet fully connected layers
        cbhg_rnn_units = 128  # Number of GRU units used in bidirectional RNN of CBHG block. CBHG output is 2x rnn_units in shape
        batch_norm_position = 'before'
        is_training = bool(behavior == self.train)
        post_cbhg = CBHG(cbhg_kernels, cbhg_conv_channels, cbhg_pool_size,
                         [cbhg_projection, FLAGS.PARAM.OUTPUT_SIZE],
                         cbhg_projection_kernel_size, cbhg_highwaynet_layers,
                         cbhg_highway_units, cbhg_rnn_units,
                         batch_norm_position, is_training,
                         name='CBHG_postnet')

        # [batch_size, decoder_steps(mel_frames), cbhg_channels]
        self._cbhg_inputs_y_est = self._y_estimation
        cbhg_outputs = post_cbhg(self._y_estimation, None)
        frame_projector = FrameProjection(FLAGS.PARAM.OUTPUT_SIZE,
                                          scope='CBHG_proj_to_spec')
        self._y_estimation = frame_projector(cbhg_outputs)

        if (FLAGS.PARAM.DECODING_MASK_POSITION !=
                FLAGS.PARAM.TRAINING_MASK_POSITION):
            print(
                'DECODING_MASK_POSITION must be equal to TRAINING_MASK_POSITION when use CBHG post processing.'
            )
            exit(-1)
        # The post-net input was converted to the LABEL_TYPE domain in the
        # "prepare y_estimation" region and its output is compared against
        # _y_labels, so undo the LABEL_TYPE normalization here. This equals
        # the decoding position whenever TRAINING_MASK_POSITION == LABEL_TYPE.
        if FLAGS.PARAM.LABEL_TYPE == 'mag':
            self._y_mag_estimation = rm_norm_mag_spec(
                self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX)
        elif FLAGS.PARAM.LABEL_TYPE == 'logmag':
            self._y_mag_estimation = rm_norm_logmag_spec(
                self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias,
                FLAGS.PARAM.MIN_LOG_BIAS)
    # endregion

    self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=30)
    if behavior == self.infer:
        return

    # region get labels LOSS
    # Labels: the phase-aware variants scale the clean target by a function
    # of cos(theta_x - theta_y).
    if FLAGS.PARAM.MASK_TYPE == 'PSM':
        self._y_labels *= tf.cos(self._x_theta - self._y_theta)
    elif FLAGS.PARAM.MASK_TYPE == 'fixPSM':
        self._y_labels *= (1.0 + tf.cos(self._x_theta - self._y_theta)) * 0.5
    elif FLAGS.PARAM.MASK_TYPE == 'AcutePM':
        self._y_labels *= tf.nn.relu(tf.cos(self._x_theta - self._y_theta))
    elif FLAGS.PARAM.MASK_TYPE == 'PowFixPSM':
        self._y_labels *= tf.pow(
            tf.abs((1.0 + tf.cos(self._x_theta - self._y_theta)) * 0.5),
            FLAGS.PARAM.POW_FIX_PSM_COEF)
    elif FLAGS.PARAM.MASK_TYPE == 'IRM':
        pass
    else:
        tf.logging.error('Mask type error.')
        exit(-1)

    # LOSS
    if FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'SPEC_MSE':  # log_mag and mag MSE
        self._loss = loss.reduce_sum_frame_batchsize_MSE(
            self._y_estimation, self._y_labels)
        if FLAGS.PARAM.USE_CBHG_POST_PROCESSING:
            if FLAGS.PARAM.DOUBLE_LOSS:
                # Also supervise the post-net input (the pre-CBHG estimate).
                self._loss = FLAGS.PARAM.CBHG_LOSS_COEF1 * loss.reduce_sum_frame_batchsize_MSE(
                    self._cbhg_inputs_y_est, self._y_labels
                ) + FLAGS.PARAM.CBHG_LOSS_COEF2 * self._loss
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MFCC_SPEC_MSE':
        self._loss1, self._loss2 = loss.balanced_MFCC_AND_SPEC_MSE(
            self._y_estimation, self._y_labels, self._y_mag_estimation,
            self._y_mag_spec)
        self._loss = FLAGS.PARAM.SPEC_LOSS_COEF * self._loss1 + FLAGS.PARAM.MFCC_LOSS_COEF * self._loss2
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MEL_MAG_MSE':
        self._loss1, self._loss2 = loss.balanced_MEL_AND_SPEC_MSE(
            self._y_estimation, self._y_labels, self._y_mag_estimation,
            self._y_mag_spec)
        self._loss = FLAGS.PARAM.SPEC_LOSS_COEF * self._loss1 + FLAGS.PARAM.MEL_LOSS_COEF * self._loss2
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "SPEC_MSE_LOWF_EN":
        self._loss = loss.reduce_sum_frame_batchsize_MSE(
            self._y_estimation, self._y_labels)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "FAIR_SPEC_MSE":
        self._loss = loss.fair_reduce_sum_frame_batchsize_MSE(
            self._y_estimation, self._y_labels)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "SPEC_MSE_FLEXIBLE_POW_C":
        self._loss = loss.reduce_sum_frame_batchsize_MSE_EmphasizeLowerValue(
            self._y_estimation, self._y_labels, FLAGS.PARAM.POW_COEF)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "RELATED_MSE":
        self._loss = loss.relative_reduce_sum_frame_batchsize_MSE(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.RELATED_MSE_IGNORE_TH)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE2":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v2(
            self._y_estimation,
            self._y_labels,
            FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG,
            FLAGS.PARAM.LINEAR_BROKER,
        )
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE3":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v3(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS3_A,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS3_B,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS3_C1,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS3_C2)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE4":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v4(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE5":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v5(
            self._y_estimation, self._y_labels)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE6":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v6(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS6_A,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS6_B,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS6_C1,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS6_C2)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE7":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v7(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS7_A1,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS7_A2,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS7_B,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS7_C1,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS7_C2)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE8":
        self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v8(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS8_A,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS8_B,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS8_C1,
            FLAGS.PARAM.AUTO_RELATIVE_LOSS8_C2)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE_USE_COS":
        self._loss = loss.cos_auto_ingore_relative_reduce_sum_frame_batchsize_MSE(
            self._y_estimation, self._y_labels,
            FLAGS.PARAM.COS_AUTO_RELATED_MSE_W)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MEL_AUTO_RELATED_MSE':
        # type(y_estimation) = FLAGS.PARAM.LABEL_TYPE
        self._loss = loss.MEL_AUTO_RELATIVE_MSE(
            self._y_estimation, self._norm_y_mag_spec, FLAGS.PARAM.MEL_NUM,
            FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG)
    else:
        print('Loss type error.')
        exit(-1)
    # endregion

    if behavior == self.validation:
        # The validation model only evaluates the loss; it cannot train.
        return
    self._lr = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars),
                                      FLAGS.PARAM.CLIP_NORM)
    optimizer = tf.train.AdamOptimizer(self.lr)
    self._train_op = optimizer.apply_gradients(zip(grads, tvars))

    # Learning rate is changed at run time by feeding _new_lr to _lr_update.
    self._new_lr = tf.placeholder(tf.float32, shape=[],
                                  name='new_learning_rate')
    self._lr_update = tf.assign(self._lr, self._new_lr)
def __init__(self, x_mag_spec_batch, lengths_batch, y_mag_spec_batch=None,
             theta_x_batch=None, theta_y_batch=None, behavior='train'):
    '''Build the bidirectional-RNN masking model with a noise-threshold option.

    behavior = 'train/validation/infer'

    The normalized noisy spectrum is encoded by a stacked BLSTM or BGRU, a
    fully connected ReLU layer predicts a mask, and, when
    FLAGS.PARAM.THRESHOLD_FUNC is set, threshold_feature is applied either
    to the mask or to the masked spectrum. The saver is always built; the
    loss is built for train/validation and the optimizer only for train.

    Args:
      x_mag_spec_batch: noisy magnitude spectrum batch. Presumably
        [batch, time, freq] -- the mask is reshaped to
        [batch, -1, OUTPUT_SIZE] before being multiplied with it (TODO
        confirm).
      lengths_batch: per-utterance lengths; its first dimension is used as
        the batch size and it is the RNN sequence_length.
      y_mag_spec_batch: clean magnitude spectrum batch; required unless
        behavior is infer.
      theta_x_batch: presumably the noisy phase; only used through
        cos(theta_x - theta_y) for the PSM label. Required unless infer.
      theta_y_batch: presumably the clean phase; see theta_x_batch.
      behavior: compared against self.train / self.validation / self.infer.

    Raises:
      ValueError: FLAGS.PARAM.MODEL_TYPE is neither BLSTM nor BGRU
        (case-insensitive).
      SystemExit: via exit(-1) on an invalid threshold position, mask type or
        loss type, or when DECODING_MASK_POSITION != TRAINING_MASK_POSITION.
    '''
    if behavior != self.infer:
        assert(y_mag_spec_batch is not None)
        assert(theta_x_batch is not None)
        assert(theta_y_batch is not None)
    # Offset used by the log-magnitude (de)normalization helpers.
    self._log_bias = tf.get_variable(
        'logbias', [1], trainable=FLAGS.PARAM.LOG_BIAS_TRAINABLE,
        initializer=tf.constant_initializer(FLAGS.PARAM.INIT_LOG_BIAS))
    self._real_logbias = self._log_bias + FLAGS.PARAM.MIN_LOG_BIAS
    self._x_mag_spec = x_mag_spec_batch
    self._norm_x_mag_spec = norm_mag_spec(self._x_mag_spec,
                                          FLAGS.PARAM.MAG_NORM_MAX)
    self._norm_x_logmag_spec = norm_logmag_spec(self._x_mag_spec,
                                                FLAGS.PARAM.MAG_NORM_MAX,
                                                self._log_bias,
                                                FLAGS.PARAM.MIN_LOG_BIAS)

    self._y_mag_spec = y_mag_spec_batch
    self._norm_y_mag_spec = norm_mag_spec(self._y_mag_spec,
                                          FLAGS.PARAM.MAG_NORM_MAX)
    self._norm_y_logmag_spec = norm_logmag_spec(self._y_mag_spec,
                                                FLAGS.PARAM.MAG_NORM_MAX,
                                                self._log_bias,
                                                FLAGS.PARAM.MIN_LOG_BIAS)

    self._lengths = lengths_batch
    self._batch_size = tf.shape(self._lengths)[0]

    self._x_theta = theta_x_batch
    self._y_theta = theta_y_batch
    self._model_type = FLAGS.PARAM.MODEL_TYPE
    # Upper-cased once so every model-type comparison below is
    # case-insensitive (the RNN selection always used .upper()).
    model_type = self._model_type.upper()

    if FLAGS.PARAM.INPUT_TYPE == 'mag':
        self.net_input = self._norm_x_mag_spec
    elif FLAGS.PARAM.INPUT_TYPE == 'logmag':
        self.net_input = self._norm_x_logmag_spec
    if FLAGS.PARAM.LABEL_TYPE == 'mag':
        self._y_labels = self._norm_y_mag_spec
    elif FLAGS.PARAM.LABEL_TYPE == 'logmag':
        self._y_labels = self._norm_y_logmag_spec

    outputs = self.net_input

    # Outside inference, wrap every RNN cell with output dropout.
    lstm_attn_cell = lstm_cell
    if behavior != self.infer and FLAGS.PARAM.KEEP_PROB < 1.0:
        def lstm_attn_cell(n_units, n_proj, act):
            return tf.contrib.rnn.DropoutWrapper(
                lstm_cell(n_units, n_proj, act),
                output_keep_prob=FLAGS.PARAM.KEEP_PROB)

    GRU_attn_cell = GRU_cell
    if behavior != self.infer and FLAGS.PARAM.KEEP_PROB < 1.0:
        def GRU_attn_cell(n_units, act):
            return tf.contrib.rnn.DropoutWrapper(
                GRU_cell(n_units, act),
                output_keep_prob=FLAGS.PARAM.KEEP_PROB)

    if model_type == 'BLSTM':
        with tf.variable_scope('BLSTM'):
            lstm_fw_cell = tf.contrib.rnn.MultiRNNCell(
                [lstm_attn_cell(FLAGS.PARAM.RNN_SIZE,
                                FLAGS.PARAM.LSTM_num_proj,
                                FLAGS.PARAM.LSTM_ACTIVATION)
                 for _ in range(FLAGS.PARAM.RNN_LAYER)],
                state_is_tuple=True)
            lstm_bw_cell = tf.contrib.rnn.MultiRNNCell(
                [lstm_attn_cell(FLAGS.PARAM.RNN_SIZE,
                                FLAGS.PARAM.LSTM_num_proj,
                                FLAGS.PARAM.LSTM_ACTIVATION)
                 for _ in range(FLAGS.PARAM.RNN_LAYER)],
                state_is_tuple=True)
            fw_cell = lstm_fw_cell._cells
            bw_cell = lstm_bw_cell._cells
            result = rnn.stack_bidirectional_dynamic_rnn(
                cells_fw=fw_cell,
                cells_bw=bw_cell,
                inputs=outputs,
                dtype=tf.float32,
                sequence_length=self._lengths)
            outputs, fw_final_states, bw_final_states = result
    elif model_type == 'BGRU':
        with tf.variable_scope('BGRU'):
            gru_fw_cell = tf.contrib.rnn.MultiRNNCell(
                [GRU_attn_cell(FLAGS.PARAM.RNN_SIZE,
                               FLAGS.PARAM.LSTM_ACTIVATION)
                 for _ in range(FLAGS.PARAM.RNN_LAYER)],
                state_is_tuple=True)
            gru_bw_cell = tf.contrib.rnn.MultiRNNCell(
                [GRU_attn_cell(FLAGS.PARAM.RNN_SIZE,
                               FLAGS.PARAM.LSTM_ACTIVATION)
                 for _ in range(FLAGS.PARAM.RNN_LAYER)],
                state_is_tuple=True)
            fw_cell = gru_fw_cell._cells
            bw_cell = gru_bw_cell._cells
            result = rnn.stack_bidirectional_dynamic_rnn(
                cells_fw=fw_cell,
                cells_bw=bw_cell,
                inputs=outputs,
                dtype=tf.float32,
                sequence_length=self._lengths)
            outputs, fw_final_states, bw_final_states = result
    else:
        # Fail fast: otherwise the mask layer would be fed the un-encoded
        # input with a mismatched width.
        raise ValueError('Unknown model type %s.' % FLAGS.PARAM.MODEL_TYPE)

    # region full connection get mask
    # Width of the RNN output: both directions are concatenated; a projected
    # LSTM emits LSTM_num_proj units per direction instead of RNN_SIZE.
    in_size = FLAGS.PARAM.RNN_SIZE * 2
    if model_type == 'BLSTM' and FLAGS.PARAM.LSTM_num_proj is not None:
        in_size = 2 * FLAGS.PARAM.LSTM_num_proj
    outputs = tf.reshape(outputs, [-1, in_size])
    out_size = FLAGS.PARAM.OUTPUT_SIZE
    with tf.variable_scope('fullconnectOut'):
        weights = tf.get_variable(
            'weights1', [in_size, out_size],
            initializer=tf.random_normal_initializer(stddev=0.01))
        biases = tf.get_variable('biases1', [out_size],
                                 initializer=tf.constant_initializer(0.0))
    mask = tf.nn.relu(tf.matmul(outputs, weights) + biases)
    self._mask = tf.reshape(
        mask, [self._batch_size, -1, FLAGS.PARAM.OUTPUT_SIZE])
    # endregion

    # threshold_feature consumes the RNN features in [batch, time, in_size].
    outputs = tf.reshape(outputs, [self._batch_size, -1, in_size])

    # region Apply Noise Threshold Function on Mask
    if FLAGS.PARAM.THRESHOLD_FUNC is not None:
        # use noise threshold
        if FLAGS.PARAM.THRESHOLD_POS == FLAGS.PARAM.THRESHOLD_ON_MASK:
            self._mask, self._threshold = threshold_feature(
                self._mask, outputs, self._batch_size, in_size)
        elif FLAGS.PARAM.THRESHOLD_POS == FLAGS.PARAM.THRESHOLD_ON_SPEC:
            pass
        else:
            print('Threshold position error!')
            exit(-1)
    # endregion

    # region prepare y_estimation
    if FLAGS.PARAM.TRAINING_MASK_POSITION == 'mag':
        self._y_estimation = self._mask * self._norm_x_mag_spec
    elif FLAGS.PARAM.TRAINING_MASK_POSITION == 'logmag':
        self._y_estimation = self._mask * self._norm_x_logmag_spec
    # endregion

    # region Apply Noise Threshold Function on Spec(log or mag)
    if FLAGS.PARAM.THRESHOLD_FUNC is not None:
        # use noise threshold
        if FLAGS.PARAM.THRESHOLD_POS == FLAGS.PARAM.THRESHOLD_ON_MASK:
            pass
        elif FLAGS.PARAM.THRESHOLD_POS == FLAGS.PARAM.THRESHOLD_ON_SPEC:
            self._y_estimation, self._threshold = threshold_feature(
                self._y_estimation, outputs, self._batch_size, in_size)
    # endregion

    # region get infer spec
    # _y_estimation is in the TRAINING_MASK_POSITION domain here, so decoding
    # with any other position would apply the wrong inverse normalization.
    if (FLAGS.PARAM.DECODING_MASK_POSITION !=
            FLAGS.PARAM.TRAINING_MASK_POSITION):
        print('Error, DECODING_MASK_POSITION should be equal to TRAINING_MASK_POSITION when use thresohold model.')
        exit(-1)
    # _y_mag_estimation is the estimated (de-normalized) magnitude spectrum;
    # _y_estimation is the loss target domain (normalized mag or logmag).
    if FLAGS.PARAM.DECODING_MASK_POSITION == 'mag':
        self._y_mag_estimation = rm_norm_mag_spec(self._y_estimation,
                                                  FLAGS.PARAM.MAG_NORM_MAX)
    elif FLAGS.PARAM.DECODING_MASK_POSITION == 'logmag':
        self._y_mag_estimation = rm_norm_logmag_spec(self._y_estimation,
                                                     FLAGS.PARAM.MAG_NORM_MAX,
                                                     self._log_bias,
                                                     FLAGS.PARAM.MIN_LOG_BIAS)
    # endregion

    # Bring the estimate into the same domain as the labels.
    if FLAGS.PARAM.TRAINING_MASK_POSITION != FLAGS.PARAM.LABEL_TYPE:
        if FLAGS.PARAM.LABEL_TYPE == 'mag':
            self._y_estimation = normedLogmag2normedMag(
                self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias,
                FLAGS.PARAM.MIN_LOG_BIAS)
        elif FLAGS.PARAM.LABEL_TYPE == 'logmag':
            self._y_estimation = normedMag2normedLogmag(
                self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias,
                FLAGS.PARAM.MIN_LOG_BIAS)

    self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=30)
    if behavior == self.infer:
        return

    # region prepare y_labels
    # Built only after the infer early-return: the phase inputs are optional
    # (None) at inference and the labels are only consumed by the loss.
    if FLAGS.PARAM.MASK_TYPE == 'PSM':
        self._y_labels *= tf.cos(self._x_theta - self._y_theta)
    elif FLAGS.PARAM.MASK_TYPE == 'IRM':
        pass
    else:
        tf.logging.error('Mask type error.')
        exit(-1)
    # endregion

    # region get LOSS
    if FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'SPEC_MSE':  # log_mag and mag MSE
        self._loss = loss.reduce_sum_frame_batchsize_MSE(self._y_estimation,
                                                         self._y_labels)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MFCC_SPEC_MSE':
        self._loss1, self._loss2 = loss.balanced_MFCC_AND_SPEC_MSE(
            self._y_estimation, self._y_labels, self._y_mag_estimation,
            self._y_mag_spec)
        self._loss = (FLAGS.PARAM.SPEC_LOSS_COEF * self._loss1 +
                      FLAGS.PARAM.MFCC_LOSS_COEF * self._loss2)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MEL_MAG_MSE':
        self._loss1, self._loss2 = loss.balanced_MEL_AND_SPEC_MSE(
            self._y_estimation, self._y_labels, self._y_mag_estimation,
            self._y_mag_spec)
        self._loss = (FLAGS.PARAM.SPEC_LOSS_COEF * self._loss1 +
                      FLAGS.PARAM.MEL_LOSS_COEF * self._loss2)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "SPEC_MSE_LOWF_EN":
        self._loss = loss.reduce_sum_frame_batchsize_MSE(self._y_estimation,
                                                         self._y_labels)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "FAIR_SPEC_MSE":
        self._loss = loss.fair_reduce_sum_frame_batchsize_MSE(
            self._y_estimation, self._y_labels)
    elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "SPEC_MSE_FLEXIBLE_POW_C":
        self._loss = loss.reduce_sum_frame_batchsize_MSE_EmphasizeLowerValue(
            self._y_estimation, self._y_labels, FLAGS.PARAM.POW_COEF)
    else:
        print('Loss type error.')
        exit(-1)
    # endregion

    if behavior == self.validation:
        # The validation model only evaluates the loss; it cannot train.
        return
    self._lr = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars),
                                      FLAGS.PARAM.CLIP_NORM)
    optimizer = tf.train.AdamOptimizer(self.lr)
    self._train_op = optimizer.apply_gradients(zip(grads, tvars))

    # Learning rate is changed at run time by feeding _new_lr to _lr_update.
    self._new_lr = tf.placeholder(
        tf.float32, shape=[], name='new_learning_rate')
    self._lr_update = tf.assign(self._lr, self._new_lr)
def __init__(self, x_mag_spec_batch, lengths_batch, y_mag_spec_batch=None, theta_x_batch=None, theta_y_batch=None, behavior='train'): ''' behavior = 'train/validation/infer' ''' if behavior != self.infer: assert (y_mag_spec_batch is not None) assert (theta_x_batch is not None) assert (theta_y_batch is not None) self._log_bias = tf.get_variable( 'logbias', [1], trainable=FLAGS.PARAM.LOG_BIAS_TRAINABLE, initializer=tf.constant_initializer(FLAGS.PARAM.INIT_LOG_BIAS)) self._real_logbias = self._log_bias + FLAGS.PARAM.MIN_LOG_BIAS self._x_mag_spec = x_mag_spec_batch self._norm_x_mag_spec = norm_mag_spec(self._x_mag_spec, FLAGS.PARAM.MAG_NORM_MAX) self._norm_x_logmag_spec = norm_logmag_spec(self._x_mag_spec, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias, FLAGS.PARAM.MIN_LOG_BIAS) self._y_mag_spec = y_mag_spec_batch self._norm_y_mag_spec = norm_mag_spec(self._y_mag_spec, FLAGS.PARAM.MAG_NORM_MAX) self._norm_y_logmag_spec = norm_logmag_spec(self._y_mag_spec, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias, FLAGS.PARAM.MIN_LOG_BIAS) self._lengths = lengths_batch self._batch_size = tf.shape(self._lengths)[0] self._x_theta = theta_x_batch self._y_theta = theta_y_batch self._model_type = FLAGS.PARAM.MODEL_TYPE if FLAGS.PARAM.INPUT_TYPE == 'mag': self.net_input = self._norm_x_mag_spec elif FLAGS.PARAM.INPUT_TYPE == 'logmag': self.net_input = self._norm_x_logmag_spec if FLAGS.PARAM.LABEL_TYPE == 'mag': self._y_labels = self._norm_y_mag_spec elif FLAGS.PARAM.LABEL_TYPE == 'logmag': self._y_labels = self._norm_y_logmag_spec outputs = self.net_input # [batch, time, ...] if FLAGS.PARAM.OUTPUTS_LATER_SHIFT_FRAMES > 0: padding_zeros = tf.zeros([ self._batch_size, FLAGS.PARAM.OUTPUTS_LATER_SHIFT_FRAMES, tf.shape(outputs)[-1] ]) outputs = tf.concat([outputs, padding_zeros], -2) if FLAGS.PARAM.MODEL_TYPE.upper() in [ "BGRU", "BLSTM", "UNIGRU", "UNILSTM" ]: # in: outputs [batch, time, ...] 
# out: outputs [batch, time, ...], insize: shape(outputs)[-1] lstm_attn_cell = lstm_cell if behavior != self.infer and FLAGS.PARAM.KEEP_PROB < 1.0: def lstm_attn_cell(n_units, n_proj, act): return tf.contrib.rnn.DropoutWrapper( lstm_cell(n_units, n_proj, act), output_keep_prob=FLAGS.PARAM.KEEP_PROB) GRU_attn_cell = GRU_cell if behavior != self.infer and FLAGS.PARAM.KEEP_PROB < 1.0: def GRU_attn_cell(n_units, act): return tf.contrib.rnn.DropoutWrapper( GRU_cell(n_units, act), output_keep_prob=FLAGS.PARAM.KEEP_PROB) if FLAGS.PARAM.MODEL_TYPE.upper() == 'UNIGRU': with tf.variable_scope('UNI_GRU'): gru_cell = tf.contrib.rnn.MultiRNNCell([ GRU_attn_cell(FLAGS.PARAM.RNN_SIZE, FLAGS.PARAM.LSTM_ACTIVATION) for _ in range(FLAGS.PARAM.RNN_LAYER) ], state_is_tuple=True) # _cell = gru_cell._cells result = tf.nn.dynamic_rnn( gru_cell, outputs, dtype=tf.float32, sequence_length=self._lengths, ) outputs, final_states = result if FLAGS.PARAM.MODEL_TYPE.upper() == 'UNILSTM': with tf.variable_scope('UNI_LSTM'): lstm_cells__t = tf.contrib.rnn.MultiRNNCell( [ lstm_attn_cell(FLAGS.PARAM.RNN_SIZE, FLAGS.PARAM.LSTM_num_proj, FLAGS.PARAM.LSTM_ACTIVATION) for _ in range(FLAGS.PARAM.RNN_LAYER) ], state_is_tuple=True) # _cell = lstm_cell._cells result = tf.nn.dynamic_rnn( lstm_cells__t, outputs, dtype=tf.float32, sequence_length=self._lengths, ) outputs, final_states = result if FLAGS.PARAM.MODEL_TYPE.upper() == 'BLSTM': with tf.variable_scope('BLSTM'): lstm_fw_cell = tf.contrib.rnn.MultiRNNCell( [ lstm_attn_cell(FLAGS.PARAM.RNN_SIZE, FLAGS.PARAM.LSTM_num_proj, FLAGS.PARAM.LSTM_ACTIVATION) for _ in range(FLAGS.PARAM.RNN_LAYER) ], state_is_tuple=True) lstm_bw_cell = tf.contrib.rnn.MultiRNNCell( [ lstm_attn_cell(FLAGS.PARAM.RNN_SIZE, FLAGS.PARAM.LSTM_num_proj, FLAGS.PARAM.LSTM_ACTIVATION) for _ in range(FLAGS.PARAM.RNN_LAYER) ], state_is_tuple=True) fw_cell = lstm_fw_cell._cells bw_cell = lstm_bw_cell._cells result = rnn.stack_bidirectional_dynamic_rnn( cells_fw=fw_cell, cells_bw=bw_cell, 
inputs=outputs, dtype=tf.float32, sequence_length=self._lengths) outputs, fw_final_states, bw_final_states = result if FLAGS.PARAM.MODEL_TYPE.upper() == 'BGRU': with tf.variable_scope('BGRU'): gru_fw_cell = tf.contrib.rnn.MultiRNNCell( [ GRU_attn_cell(FLAGS.PARAM.RNN_SIZE, FLAGS.PARAM.LSTM_ACTIVATION) for _ in range(FLAGS.PARAM.RNN_LAYER) ], state_is_tuple=True) gru_bw_cell = tf.contrib.rnn.MultiRNNCell( [ GRU_attn_cell(FLAGS.PARAM.RNN_SIZE, FLAGS.PARAM.LSTM_ACTIVATION) for _ in range(FLAGS.PARAM.RNN_LAYER) ], state_is_tuple=True) fw_cell = gru_fw_cell._cells bw_cell = gru_bw_cell._cells result = rnn.stack_bidirectional_dynamic_rnn( cells_fw=fw_cell, cells_bw=bw_cell, inputs=outputs, dtype=tf.float32, sequence_length=self._lengths) outputs, fw_final_states, bw_final_states = result # self.fw_final_state = fw_final_states # self.bw_final_state = bw_final_states # print(fw_final_states[0][0].get_shape().as_list()) # print(np.shape(fw_final_states),np.shape(bw_final_states)) # calcu rnn output size in_size = FLAGS.PARAM.RNN_SIZE mask = None if self._model_type.upper()[0] == 'B': # bidirection rnn_output_num = FLAGS.PARAM.RNN_SIZE * 2 if FLAGS.PARAM.MODEL_TYPE == 'BLSTM' and ( not (FLAGS.PARAM.LSTM_num_proj is None)): rnn_output_num = 2 * FLAGS.PARAM.LSTM_num_proj in_size = rnn_output_num elif FLAGS.PARAM.MODEL_TYPE.upper() == 'TRANSFORMER': # in: outputs [batch, time, ...] # out: outputs [batch, time, ...], insize: shape(outputs)[-1] is_training = (behavior == self.train) n_self_att_blocks = FLAGS.PARAM.n_self_att_blocks d_model = FLAGS.PARAM.RNN_SIZE num_att_heads = FLAGS.PARAM.num_att_heads d_positionwise_FC = FLAGS.PARAM.d_positionwise_FC with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE): # inputs embedding trans = outputs trans *= FLAGS.PARAM.FFT_DOT**0.5 # scale trans += transformer_utils.positional_encoding( trans, 2000) # TODO fixed length? 
trans = tf.layers.dropout(trans, 1.0 - FLAGS.PARAM.KEEP_PROB, training=is_training) trans = tf.layers.dense(trans, d_model, use_bias=False) ## Blocks, for i in range(n_self_att_blocks): with tf.variable_scope("blocks_{}".format(i), reuse=tf.AUTO_REUSE): # self-attention trans = transformer_utils.multihead_attention( queries=trans, keys=trans, values=trans, d_model=d_model, KV_lengths=self.lengths, Q_lengths=self.lengths, num_heads=num_att_heads, dropout_rate=1.0 - FLAGS.PARAM.KEEP_PROB, training=is_training, causality=False) # position-wise feedforward trans = transformer_utils.positionwise_FC( trans, num_units=[d_positionwise_FC, d_model]) outputs = trans # [batch, time_src, d_model] in_size = d_model else: raise ValueError('Unknown model type %s.' % FLAGS.PARAM.MODEL_TYPE) if FLAGS.PARAM.OUTPUTS_LATER_SHIFT_FRAMES > 0: outputs = tf.slice(outputs, [0, FLAGS.PARAM.OUTPUTS_LATER_SHIFT_FRAMES, 0], [-1, -1, -1]) if FLAGS.PARAM.POST_1D_CNN: outputs = tf.layers.conv1d(outputs, filters=in_size, use_bias=True, kernel_size=FLAGS.PARAM.CNN_1D_WIDTH, padding="same", reuse=tf.AUTO_REUSE) # region full connection get mask outputs = tf.reshape(outputs, [-1, in_size]) out_size = FLAGS.PARAM.OUTPUT_SIZE with tf.variable_scope('fullconnectOut'): weights = tf.get_variable( 'weights1', [in_size, out_size], initializer=tf.random_normal_initializer(stddev=0.01)) biases = tf.get_variable('biases1', [out_size], initializer=tf.constant_initializer( FLAGS.PARAM.INIT_MASK_VAL)) linear_out = tf.matmul(outputs, weights) + biases mask = linear_out if FLAGS.PARAM.ReLU_MASK: mask = tf.nn.relu(linear_out) # endregion full connection self._mask = tf.reshape( mask, [self._batch_size, -1, FLAGS.PARAM.OUTPUT_SIZE]) if FLAGS.PARAM.TRAINING_MASK_POSITION == 'mag': self._y_estimation = self._mask * (self._norm_x_mag_spec + FLAGS.PARAM.SPEC_EST_BIAS) elif FLAGS.PARAM.TRAINING_MASK_POSITION == 'logmag': self._y_estimation = self._mask * (self._norm_x_logmag_spec + FLAGS.PARAM.SPEC_EST_BIAS) # region get 
infer spec if FLAGS.PARAM.DECODING_MASK_POSITION == 'mag': self._y_mag_estimation = rm_norm_mag_spec( self._mask * (self._norm_x_mag_spec + FLAGS.PARAM.SPEC_EST_BIAS), FLAGS.PARAM.MAG_NORM_MAX) elif FLAGS.PARAM.DECODING_MASK_POSITION == 'logmag': self._y_mag_estimation = rm_norm_logmag_spec( self._mask * (self._norm_x_logmag_spec + FLAGS.PARAM.SPEC_EST_BIAS), FLAGS.PARAM.MAG_NORM_MAX, self._log_bias, FLAGS.PARAM.MIN_LOG_BIAS) ''' _y_mag_estimation is estimated mag_spec _y_estimation is loss_targe, mag_sepec or logmag_spec ''' # endregion # region prepare y_estimation if FLAGS.PARAM.TRAINING_MASK_POSITION != FLAGS.PARAM.LABEL_TYPE: if FLAGS.PARAM.LABEL_TYPE == 'mag': self._y_estimation = normedLogmag2normedMag( self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias, FLAGS.PARAM.MIN_LOG_BIAS) elif FLAGS.PARAM.LABEL_TYPE == 'logmag': self._y_estimation = normedMag2normedLogmag( self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias, FLAGS.PARAM.MIN_LOG_BIAS) # endregion # region CBHG if FLAGS.PARAM.USE_CBHG_POST_PROCESSING: cbhg_kernels = 8 # All kernel sizes from 1 to cbhg_kernels will be used in the convolution bank of CBHG to act as "K-grams" cbhg_conv_channels = 128 # Channels of the convolution bank cbhg_pool_size = 2 # pooling size of the CBHG cbhg_projection = 256 # projection channels of the CBHG (1st projection, 2nd is automatically set to num_mels) cbhg_projection_kernel_size = 3 # kernel_size of the CBHG projections cbhg_highwaynet_layers = 4 # Number of HighwayNet layers cbhg_highway_units = 128 # Number of units used in HighwayNet fully connected layers cbhg_rnn_units = 128 # Number of GRU units used in bidirectional RNN of CBHG block. 
CBHG output is 2x rnn_units in shape batch_norm_position = 'before' # is_training = True is_training = bool(behavior == self.train) post_cbhg = CBHG(cbhg_kernels, cbhg_conv_channels, cbhg_pool_size, [cbhg_projection, FLAGS.PARAM.OUTPUT_SIZE], cbhg_projection_kernel_size, cbhg_highwaynet_layers, cbhg_highway_units, cbhg_rnn_units, batch_norm_position, is_training, name='CBHG_postnet') #[batch_size, decoder_steps(mel_frames), cbhg_channels] self._cbhg_inputs_y_est = self._y_estimation cbhg_outputs = post_cbhg(self._y_estimation, None) frame_projector = FrameProjection(FLAGS.PARAM.OUTPUT_SIZE, scope='CBHG_proj_to_spec') self._y_estimation = frame_projector(cbhg_outputs) if FLAGS.PARAM.DECODING_MASK_POSITION != FLAGS.PARAM.TRAINING_MASK_POSITION: print( 'DECODING_MASK_POSITION must be equal to TRAINING_MASK_POSITION when use CBHG post processing.' ) exit(-1) if FLAGS.PARAM.DECODING_MASK_POSITION == 'mag': self._y_mag_estimation = rm_norm_mag_spec( self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX) elif FLAGS.PARAM.DECODING_MASK_POSITION == 'logmag': self._y_mag_estimation = rm_norm_logmag_spec( self._y_estimation, FLAGS.PARAM.MAG_NORM_MAX, self._log_bias, FLAGS.PARAM.MIN_LOG_BIAS) # endregion self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=30) if behavior == self.infer: return # region get labels LOSS # Labels if FLAGS.PARAM.MASK_TYPE == 'PSM': self._y_labels *= tf.cos(self._x_theta - self._y_theta) elif FLAGS.PARAM.MASK_TYPE == 'fixPSM': self._y_labels *= (1.0 + tf.cos(self._x_theta - self._y_theta)) * 0.5 elif FLAGS.PARAM.MASK_TYPE == 'AcutePM': self._y_labels *= tf.nn.relu(tf.cos(self._x_theta - self._y_theta)) elif FLAGS.PARAM.MASK_TYPE == 'PowFixPSM': self._y_labels *= tf.pow( tf.abs((1.0 + tf.cos(self._x_theta - self._y_theta)) * 0.5), FLAGS.PARAM.POW_FIX_PSM_COEF) elif FLAGS.PARAM.MASK_TYPE == 'IRM': pass else: tf.logging.error('Mask type error.') exit(-1) # LOSS if FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'SPEC_MSE': # log_mag and mag MSE self._loss = 
loss.reduce_sum_frame_batchsize_MSE( self._y_estimation, self._y_labels) if FLAGS.PARAM.USE_CBHG_POST_PROCESSING: if FLAGS.PARAM.DOUBLE_LOSS: self._loss = FLAGS.PARAM.CBHG_LOSS_COEF1 * loss.reduce_sum_frame_batchsize_MSE( self._cbhg_inputs_y_est, self._y_labels ) + FLAGS.PARAM.CBHG_LOSS_COEF2 * self._loss elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MFCC_SPEC_MSE': self._loss1, self._loss2 = loss.balanced_MFCC_AND_SPEC_MSE( self._y_estimation, self._y_labels, self._y_mag_estimation, self._y_mag_spec) self._loss = FLAGS.PARAM.SPEC_LOSS_COEF * self._loss1 + FLAGS.PARAM.MFCC_LOSS_COEF * self._loss2 elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MEL_MAG_MSE': self._loss1, self._loss2 = loss.balanced_MEL_AND_SPEC_MSE( self._y_estimation, self._y_labels, self._y_mag_estimation, self._y_mag_spec) self._loss = FLAGS.PARAM.SPEC_LOSS_COEF * self._loss1 + FLAGS.PARAM.MEL_LOSS_COEF * self._loss2 elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "SPEC_MSE_LOWF_EN": self._loss = loss.reduce_sum_frame_batchsize_MSE( self._y_estimation, self._y_labels) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "FAIR_SPEC_MSE": self._loss = loss.fair_reduce_sum_frame_batchsize_MSE( self._y_estimation, self._y_labels) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "SPEC_MSE_FLEXIBLE_POW_C": self._loss = loss.reduce_sum_frame_batchsize_MSE_EmphasizeLowerValue( self._y_estimation, self._y_labels, FLAGS.PARAM.POW_COEF) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "RELATED_MSE": self._loss = loss.relative_reduce_sum_frame_batchsize_MSE( self._y_estimation, self._y_labels, FLAGS.PARAM.RELATED_MSE_IGNORE_TH) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE( self._y_estimation, self._y_labels, FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE2": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v2( self._y_estimation, self._y_labels, 
FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG, FLAGS.PARAM.LINEAR_BROKER, ) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE3": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v3( self._y_estimation, self._y_labels, FLAGS.PARAM.AUTO_RELATIVE_LOSS3_A, FLAGS.PARAM.AUTO_RELATIVE_LOSS3_B, FLAGS.PARAM.AUTO_RELATIVE_LOSS3_C1, FLAGS.PARAM.AUTO_RELATIVE_LOSS3_C2) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE4": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v4( self._y_estimation, self._y_labels, FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE5": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v5( self._y_estimation, self._y_labels) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE6": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v6( self._y_estimation, self._y_labels, FLAGS.PARAM.AUTO_RELATIVE_LOSS6_A, FLAGS.PARAM.AUTO_RELATIVE_LOSS6_B, FLAGS.PARAM.AUTO_RELATIVE_LOSS6_C1, FLAGS.PARAM.AUTO_RELATIVE_LOSS6_C2) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE7": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v7( self._y_estimation, self._y_labels, FLAGS.PARAM.AUTO_RELATIVE_LOSS7_A1, FLAGS.PARAM.AUTO_RELATIVE_LOSS7_A2, FLAGS.PARAM.AUTO_RELATIVE_LOSS7_B, FLAGS.PARAM.AUTO_RELATIVE_LOSS7_C1, FLAGS.PARAM.AUTO_RELATIVE_LOSS7_C2) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE8": self._loss = loss.auto_ingore_relative_reduce_sum_frame_batchsize_MSE_v8( self._y_estimation, self._y_labels, FLAGS.PARAM.AUTO_RELATIVE_LOSS8_A, FLAGS.PARAM.AUTO_RELATIVE_LOSS8_B, FLAGS.PARAM.AUTO_RELATIVE_LOSS8_C1, FLAGS.PARAM.AUTO_RELATIVE_LOSS8_C2) elif FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == "AUTO_RELATED_MSE_USE_COS": self._loss = loss.cos_auto_ingore_relative_reduce_sum_frame_batchsize_MSE( self._y_estimation, self._y_labels, FLAGS.PARAM.COS_AUTO_RELATED_MSE_W) elif 
FLAGS.PARAM.LOSS_FUNC_FOR_MAG_SPEC == 'MEL_AUTO_RELATED_MSE': # type(y_estimation) = FLAGS.PARAM.LABEL_TYPE self._loss = loss.MEL_AUTO_RELATIVE_MSE( self._y_estimation, self._norm_y_mag_spec, FLAGS.PARAM.MEL_NUM, FLAGS.PARAM.AUTO_RELATED_MSE_AXIS_FIT_DEG) else: print('Loss type error.') exit(-1) # endregion if behavior == self.validation: ''' val model cannot train. ''' return self._lr = tf.Variable(0.0, trainable=False) # TODO tvars = tf.trainable_variables() grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), FLAGS.PARAM.CLIP_NORM) optimizer = tf.train.AdamOptimizer(self.lr) #optimizer = tf.train.GradientDescentOptimizer(self.lr) self._train_op = optimizer.apply_gradients(zip(grads, tvars)) self._new_lr = tf.placeholder(tf.float32, shape=[], name='new_learning_rate') self._lr_update = tf.assign(self._lr, self._new_lr)