def build(self):
    # create sub-modules
    encoder = hparams.get_encoder()(self, 'encoder')

    # ===================
    # build the model

    input_shape = [
        hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, None,
        hparams.FEATURE_SIZE
    ]
    s_src_signals = tf.placeholder(hparams.COMPLEXX,
                                   input_shape,
                                   name='source_signal')
    s_dropout_keep = tf.placeholder(hparams.FLOATX, [],
                                    name='dropout_keep')
    reger = hparams.get_regularizer()
    with tf.variable_scope('global', regularizer=reger):
        # TODO add mixing coeff ?
        # get mixed signal
        s_mixed_signals = tf.reduce_sum(s_src_signals, axis=1)

        s_src_signals_pwr = tf.abs(s_src_signals)
        s_mixed_signals_phase = tf.atan2(tf.imag(s_mixed_signals),
                                         tf.real(s_mixed_signals))
        s_mixed_signals_power = tf.abs(s_mixed_signals)
        s_mixed_signals_log = tf.log1p(s_mixed_signals_power)
        # int[B, T, F]
        # float[B, T, F, E]
        s_embed = encoder(s_mixed_signals_log)
        s_embed_flat = tf.reshape(
            s_embed, [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])

        # TODO make attractor estimator a submodule ?
        estimator = hparams.get_estimator(hparams.TRAIN_ESTIMATOR_METHOD)(
            self, 'train_estimator')
        s_attractors = estimator(s_embed,
                                 s_src_pwr=s_src_signals_pwr,
                                 s_mix_pwr=s_mixed_signals_power)

        using_same_method = (hparams.INFER_ESTIMATOR_METHOD ==
                             hparams.TRAIN_ESTIMATOR_METHOD)

        if using_same_method:
            s_valid_attractors = s_attractors
        else:
            valid_estimator = hparams.get_estimator(
                hparams.INFER_ESTIMATOR_METHOD)(self, 'infer_estimator')
            assert not valid_estimator.USE_TRUTH
            s_valid_attractors = valid_estimator(s_embed)

        separator = hparams.get_separator(hparams.SEPARATOR_TYPE)(
            self, 'separator')

        s_separated_signals_pwr = separator(s_mixed_signals_power,
                                            s_attractors, s_embed_flat)

        if using_same_method:
            s_separated_signals_pwr_valid = s_separated_signals_pwr
        else:
            s_separated_signals_pwr_valid = separator(
                s_mixed_signals_power, s_valid_attractors, s_embed_flat)

        # use mixture phase and estimated power to get separated signal
        s_mixed_signals_phase = tf.expand_dims(s_mixed_signals_phase, 1)
        s_separated_signals = tf.complex(
            tf.cos(s_mixed_signals_phase) * s_separated_signals_pwr,
            tf.sin(s_mixed_signals_phase) * s_separated_signals_pwr)

        # loss and SNR for training
        # s_train_loss, v_perms, s_perm_sets = ops.pit_mse_loss(
        #     s_src_signals_pwr, s_separated_signals_pwr)
        s_train_loss, v_perms, s_perm_sets = ops.pit_mse_loss(
            s_src_signals, s_separated_signals)

        # resolve permutation
        s_perm_idxs = tf.stack([
            tf.tile(tf.expand_dims(tf.range(hparams.BATCH_SIZE), 1),
                    [1, hparams.MAX_N_SIGNAL]),
            tf.gather(v_perms, s_perm_sets)
        ], axis=2)
        s_perm_idxs = tf.reshape(
            s_perm_idxs, [hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL, 2])
        s_separated_signals = tf.gather_nd(s_separated_signals,
                                           s_perm_idxs)
        s_separated_signals = tf.reshape(s_separated_signals, [
            hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
            hparams.FEATURE_SIZE
        ])

        s_train_snr = tf.reduce_mean(
            ops.batch_snr(s_src_signals, s_separated_signals))

        # ^ for validation / inference
        s_valid_loss, v_perms, s_perm_sets = ops.pit_mse_loss(
            s_src_signals_pwr, s_separated_signals_pwr_valid)
        s_perm_idxs = tf.stack([
            tf.tile(tf.expand_dims(tf.range(hparams.BATCH_SIZE), 1),
                    [1, hparams.MAX_N_SIGNAL]),
            tf.gather(v_perms, s_perm_sets)
        ], axis=2)
        s_perm_idxs = tf.reshape(
            s_perm_idxs, [hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL, 2])
        s_separated_signals_pwr_valid_pit = tf.gather_nd(
            s_separated_signals_pwr_valid, s_perm_idxs)
        s_separated_signals_pwr_valid_pit = tf.reshape(
            s_separated_signals_pwr_valid_pit, [
                hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
                hparams.FEATURE_SIZE
            ])
        s_separated_signals_valid = tf.complex(
            tf.cos(s_mixed_signals_phase) *
            s_separated_signals_pwr_valid_pit,
            tf.sin(s_mixed_signals_phase) *
            s_separated_signals_pwr_valid_pit)
        s_separated_signals_infer = tf.complex(
            tf.cos(s_mixed_signals_phase) * s_separated_signals_pwr_valid,
            tf.sin(s_mixed_signals_phase) * s_separated_signals_pwr_valid)
        s_valid_snr = tf.reduce_mean(
            ops.batch_snr(s_src_signals, s_separated_signals_valid))

    # ===============
    # prepare summary
    # TODO add impl & summary for word error rate
    with tf.name_scope('train_summary'):
        s_loss_summary_t = tf.summary.scalar('loss', s_train_loss)
        s_snr_summary_t = tf.summary.scalar('SNR', s_train_snr)

    with tf.name_scope('valid_summary'):
        s_loss_summary_v = tf.summary.scalar('loss', s_valid_loss)
        s_snr_summary_v = tf.summary.scalar('SNR', s_valid_snr)

    # apply optimizer
    ozer = hparams.get_optimizer()(learn_rate=self.v_learn_rate,
                                   lr_decay=hparams.LR_DECAY)

    v_params_li = tf.trainable_variables()
    r_apply_grads = ozer.compute_gradients(s_train_loss, v_params_li)
    if hparams.GRAD_CLIP_THRES is not None:
        r_apply_grads = [(tf.clip_by_value(g, -hparams.GRAD_CLIP_THRES,
                                           hparams.GRAD_CLIP_THRES), v)
                         for g, v in r_apply_grads if g is not None]
    self.op_sgd_step = ozer.apply_gradients(r_apply_grads)

    self.op_init_params = tf.variables_initializer(v_params_li)
    self.op_init_states = tf.variables_initializer(
        list(self.s_states_di.values()))

    self.train_feed_keys = [s_src_signals, s_dropout_keep]
    train_summary = tf.summary.merge([s_loss_summary_t, s_snr_summary_t])
    self.train_fetches = [
        train_summary,
        dict(loss=s_train_loss, SNR=s_train_snr), self.op_sgd_step
    ]

    self.valid_feed_keys = self.train_feed_keys
    valid_summary = tf.summary.merge([s_loss_summary_v, s_snr_summary_v])
    self.valid_fetches = [
        valid_summary,
        dict(loss=s_valid_loss, SNR=s_valid_snr)
    ]

    self.infer_feed_keys = [s_mixed_signals, s_dropout_keep]
    self.infer_fetches = dict(signals=s_separated_signals_infer)

    if hparams.DEBUG:
        self.debug_feed_keys = [s_src_signals, s_dropout_keep]
        self.debug_fetches = dict(embed=s_embed,
                                  attrs=s_attractors,
                                  input=s_src_signals,
                                  output=s_separated_signals)
        self.debug_fetches.update(encoder.debug_fetches)
        self.debug_fetches.update(separator.debug_fetches)
        if estimator is not None:
            self.debug_fetches.update(estimator.debug_fetches)

    self.saver = tf.train.Saver(var_list=v_params_li)
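
# -----------------------------------------------------------------------
# The helpers `ops.pit_mse_loss` and `ops.batch_snr` live elsewhere in the
# repo; the sketches below are hypothetical reference implementations of
# what the call sites above assume (names, argument order, and return
# values are inferred from usage, not from the actual `ops` module).
# Permutation-invariant training (PIT): score every permutation of the
# reference sources against the estimates and keep the best one per item.
import itertools

import numpy as np
import tensorflow as tf


def pit_mse_loss_sketch(s_truth, s_estim, n_signal):
    '''
    Hypothetical PIT MSE loss, assuming
    s_truth, s_estim: [batch, n_signal, time, feature] tensors
    (real or complex).

    Returns:
        s_loss:      scalar mean MSE under the best permutation
        v_perms:     int32 [n_perms, n_signal], all permutations
        s_perm_sets: int32 [batch], index of the best permutation
    '''
    perms = list(itertools.permutations(range(n_signal)))
    v_perms = tf.constant(perms, dtype=tf.int32)
    s_losses = []
    for perm in perms:
        s_permuted = tf.stack([s_truth[:, i] for i in perm], axis=1)
        s_diff = s_permuted - s_estim
        # complex tensors need |.|^2 rather than square()
        s_sqr = tf.square(tf.abs(s_diff))
        s_losses.append(tf.reduce_mean(s_sqr, axis=(1, 2, 3)))
    s_losses = tf.stack(s_losses, axis=1)      # [batch, n_perms]
    s_perm_sets = tf.argmin(s_losses, axis=1)  # best perm per item
    s_loss = tf.reduce_mean(tf.reduce_min(s_losses, axis=1))
    return s_loss, v_perms, tf.cast(s_perm_sets, tf.int32)


def batch_snr_sketch(s_truth, s_estim):
    '''Hypothetical per-item SNR in dB; callers take the mean.'''
    s_noise = s_truth - s_estim
    s_pwr_truth = tf.reduce_sum(tf.square(tf.abs(s_truth)), axis=(1, 2, 3))
    s_pwr_noise = tf.reduce_sum(tf.square(tf.abs(s_noise)), axis=(1, 2, 3))
    # 10*log10(signal/noise); log10(x) = log(x) / log(10)
    return 10. * tf.log(s_pwr_truth / (s_pwr_noise + 1e-8)) / np.log(10.)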
def build(self):
    # create sub-modules
    encoder = hparams.get_encoder()(self, 'encoder')

    # ===================
    # build the model

    input_shape = [
        hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, None,
        hparams.FEATURE_SIZE
    ]
    s_src_signals = tf.placeholder(hparams.COMPLEXX,
                                   input_shape,
                                   name='source_signal')
    s_dropout_keep = tf.placeholder(hparams.FLOATX, [],
                                    name='dropout_keep')
    reger = hparams.get_regularizer()
    with tf.variable_scope('global', regularizer=reger):
        # TODO add mixing coeff ?
        # get mixed signal
        s_mixed_signals = tf.reduce_sum(s_src_signals, axis=1)

        s_src_signals_pwr = tf.abs(s_src_signals)
        s_mixed_signals_phase = tf.atan2(tf.imag(s_mixed_signals),
                                         tf.real(s_mixed_signals))
        s_mixed_signals_power = tf.abs(s_mixed_signals)
        s_mixed_signals_log = tf.log1p(s_mixed_signals_power)
        # int[B, T, F]
        # float[B, T, F, E]
        s_embed = encoder(s_mixed_signals_log)
        s_embed_flat = tf.reshape(
            s_embed, [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])
        # L2-normalize embeddings so dot products are cosine similarities
        s_embed_normalized = s_embed_flat * tf.rsqrt(
            tf.reduce_sum(tf.square(s_embed_flat), axis=-1,
                          keep_dims=True) + hparams.EPS)

        s_src_signals_pwr_flat = tf.reshape(
            s_src_signals_pwr,
            [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1])
        # dominant-source labels per TF bin; this assumes MAX_N_SIGNAL == 2
        s_pwr_diff = (s_src_signals_pwr_flat[:, 0] -
                      s_src_signals_pwr_flat[:, 1])
        assign_thres = hparams.TRAIN_ASSIGN_THRES
        s_target_category = tf.cast(
            tf.stack([(s_pwr_diff > assign_thres),
                      (s_pwr_diff < (-assign_thres))],
                     axis=2), hparams.FLOATX)

        def mse_dot_loss(s_u_, s_v_):
            '''
            Returns sum(square(dot(s_u_^T, s_v_))), batched
            '''
            return tf.reduce_sum(tf.square(
                tf.matmul(tf.transpose(s_u_, [0, 2, 1]), s_v_)),
                                 axis=(1, 2))

        # equation (3)
        # s_embed_normalized -> V
        # s_target_category -> Y
        s_train_loss = mse_dot_loss(
            s_embed_normalized, s_embed_normalized) + mse_dot_loss(
                s_target_category, s_target_category) - 2 * mse_dot_loss(
                    s_embed_normalized, s_target_category)
        s_train_loss = tf.reduce_mean(s_train_loss)

        # random unit-norm initial centers for clustering
        s_init_centers = tf.random_uniform(
            [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL,
             hparams.EMBED_SIZE], -1., 1.)
        s_init_centers *= tf.rsqrt(
            tf.reduce_sum(
                tf.square(s_init_centers), axis=-1, keep_dims=True) +
            hparams.EPS)
        s_centers = ops.kmeans(s_embed_flat,
                               s_init_centers,
                               fn_step=ops.spherical_kmeans_step,
                               max_step=hparams.MAX_KMEANS_ITERS)

        # estimate masks, and separated signals
        s_cosines = tf.matmul(s_embed_flat,
                              tf.transpose(s_centers, [0, 2, 1]))
        s_assigns = tf.argmax(s_cosines, axis=2)
        s_masks = tf.one_hot(s_assigns, hparams.MAX_N_SIGNAL)
        s_masks = tf.transpose(s_masks, [0, 2, 1])
        s_masks = tf.reshape(s_masks, [
            hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
            hparams.FEATURE_SIZE
        ])
        s_separated_signals_pwr_valid = tf.expand_dims(
            s_mixed_signals_power, 1) * s_masks

        _, v_perms, s_perm_sets = ops.pit_mse_loss(
            s_src_signals_pwr, s_separated_signals_pwr_valid)

        # use mixture phase and estimated power to get separated signal
        s_mixed_signals_phase = tf.expand_dims(s_mixed_signals_phase, 1)
        s_separated_signals_valid = tf.complex(
            tf.cos(s_mixed_signals_phase) * s_separated_signals_pwr_valid,
            tf.sin(s_mixed_signals_phase) * s_separated_signals_pwr_valid)
        # resolve permutation
        s_perm_idxs = tf.stack([
            tf.tile(tf.expand_dims(tf.range(hparams.BATCH_SIZE), 1),
                    [1, hparams.MAX_N_SIGNAL]),
            tf.gather(v_perms, s_perm_sets)
        ], axis=2)
        s_perm_idxs = tf.reshape(
            s_perm_idxs, [hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL, 2])
        s_separated_signals_valid = tf.gather_nd(s_separated_signals_valid,
                                                 s_perm_idxs)
        s_separated_signals_valid = tf.reshape(s_separated_signals_valid, [
            hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
            hparams.FEATURE_SIZE
        ])
        s_valid_snr = tf.reduce_mean(
            ops.batch_snr(s_src_signals, s_separated_signals_valid))

    # ===============
    # prepare summary
    with tf.name_scope('train_summary'):
        s_loss_summary_t = tf.summary.scalar('loss', s_train_loss)
        s_snr_summary_t = tf.summary.scalar('SNR', s_valid_snr)

    with tf.name_scope('valid_summary'):
        s_snr_summary_v = tf.summary.scalar('SNR', s_valid_snr)

    # apply optimizer
    ozer = hparams.get_optimizer()(learn_rate=self.v_learn_rate,
                                   lr_decay=hparams.LR_DECAY)

    v_params_li = tf.trainable_variables()
    r_apply_grads = ozer.compute_gradients(s_train_loss, v_params_li)
    if hparams.GRAD_CLIP_THRES is not None:
        r_apply_grads = [(tf.clip_by_value(g, -hparams.GRAD_CLIP_THRES,
                                           hparams.GRAD_CLIP_THRES), v)
                         for g, v in r_apply_grads if g is not None]
    self.op_sgd_step = ozer.apply_gradients(r_apply_grads)

    self.op_init_params = tf.variables_initializer(v_params_li)
    self.op_init_states = tf.variables_initializer(
        list(self.s_states_di.values()))

    self.train_feed_keys = [s_src_signals, s_dropout_keep]
    train_summary = tf.summary.merge([s_loss_summary_t, s_snr_summary_t])
    self.train_fetches = [
        train_summary,
        dict(loss=s_train_loss, SNR=s_valid_snr), self.op_sgd_step
    ]

    self.valid_feed_keys = self.train_feed_keys
    valid_summary = tf.summary.merge([s_snr_summary_v])
    self.valid_fetches = [valid_summary, dict(SNR=s_valid_snr)]

    self.infer_feed_keys = [s_mixed_signals, s_dropout_keep]
    self.infer_fetches = dict(signals=s_separated_signals_valid)

    if hparams.DEBUG:
        self.debug_feed_keys = [s_src_signals, s_dropout_keep]
        self.debug_fetches = dict(embed=s_embed,
                                  centers=s_centers,
                                  input=s_src_signals,
                                  output=s_separated_signals_valid)
        self.debug_fetches.update(encoder.debug_fetches)

    self.saver = tf.train.Saver(var_list=v_params_li)
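
# -----------------------------------------------------------------------
# `ops.kmeans` and `ops.spherical_kmeans_step` are repo-internal; the
# sketch below is a hypothetical rendering of the spherical k-means the
# call site above assumes (signatures inferred from usage). Because both
# embeddings and centers are unit-norm, cosine similarity reduces to a
# plain dot product, and centers are re-projected onto the unit sphere
# after every mean update.
def spherical_kmeans_step_sketch(s_embed, s_centers, n_centers, eps=1e-8):
    '''
    One hard-assignment spherical k-means update, assuming
    s_embed:   float [batch, n_points, embed], unit-norm rows
    s_centers: float [batch, n_centers, embed], unit-norm rows
    '''
    # cosine similarity between every point and every center
    s_cos = tf.matmul(s_embed, tf.transpose(s_centers, [0, 2, 1]))
    # hard assignment as a one-hot matrix: [batch, n_points, n_centers]
    s_assign = tf.one_hot(tf.argmax(s_cos, axis=2), n_centers,
                          dtype=s_embed.dtype)
    # sum of assigned points per center, re-projected to the sphere
    s_sum = tf.matmul(tf.transpose(s_assign, [0, 2, 1]), s_embed)
    return s_sum * tf.rsqrt(
        tf.reduce_sum(tf.square(s_sum), axis=-1, keep_dims=True) + eps)


def kmeans_sketch(s_embed, s_init_centers, n_centers, max_step):
    '''Fixed-step k-means, unrolled as graph-mode TF1 code suggests.'''
    s_centers = s_init_centers
    for _ in range(max_step):
        s_centers = spherical_kmeans_step_sketch(s_embed, s_centers,
                                                 n_centers)
    return s_centers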