Example #1
0
    def build(self) -> None:
        """Build the full training/validation/inference TF graph.

        Constructs the encoder -> attractor-estimator -> separator pipeline,
        a permutation-invariant (PIT) MSE loss with SNR metrics, summaries,
        the SGD update op, and init/saver ops.

        Side effects (attributes set on self): op_sgd_step, op_init_params,
        op_init_states, train/valid/infer/debug feed_keys and fetches, saver.
        """
        # create sub-modules
        encoder = hparams.get_encoder()(self, 'encoder')
        # ===================
        # build the model

        # complex spectrogram input:
        # [batch, max_n_signal, time (dynamic), feature]
        input_shape = [
            hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, None,
            hparams.FEATURE_SIZE
        ]

        s_src_signals = tf.placeholder(hparams.COMPLEXX,
                                       input_shape,
                                       name='source_signal')
        s_dropout_keep = tf.placeholder(hparams.FLOATX, [],
                                        name='dropout_keep')
        reger = hparams.get_regularizer()
        with tf.variable_scope('global', regularizer=reger):
            # TODO add mixing coeff ?

            # get mixed signal by summing the individual sources over the
            # signal axis -> complex [B, T, F]
            s_mixed_signals = tf.reduce_sum(s_src_signals, axis=1)

            s_src_signals_pwr = tf.abs(s_src_signals)
            # recover the mixture phase element-wise; reused later to turn
            # estimated magnitudes back into complex spectrograms
            s_mixed_signals_phase = tf.atan2(tf.imag(s_mixed_signals),
                                             tf.real(s_mixed_signals))
            s_mixed_signals_power = tf.abs(s_mixed_signals)
            # log(1 + |X|) magnitude compression fed to the encoder
            s_mixed_signals_log = tf.log1p(s_mixed_signals_power)
            # int[B, T, F]
            # float[B, T, F, E]
            s_embed = encoder(s_mixed_signals_log)
            # flatten time x frequency into one axis: [B, T*F, E]
            s_embed_flat = tf.reshape(
                s_embed, [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])

            # TODO make attractor estimator a submodule ?
            estimator = hparams.get_estimator(hparams.TRAIN_ESTIMATOR_METHOD)(
                self, 'train_estimator')
            # training-time attractors may use ground-truth source powers
            s_attractors = estimator(s_embed,
                                     s_src_pwr=s_src_signals_pwr,
                                     s_mix_pwr=s_mixed_signals_power)

            using_same_method = (hparams.INFER_ESTIMATOR_METHOD ==
                                 hparams.TRAIN_ESTIMATOR_METHOD)

            # reuse the training attractors at inference when both phases
            # use the same estimation method; otherwise build a separate,
            # truth-free estimator for inference
            if using_same_method:
                s_valid_attractors = s_attractors
            else:
                valid_estimator = hparams.get_estimator(
                    hparams.INFER_ESTIMATOR_METHOD)(self, 'infer_estimator')
                assert not valid_estimator.USE_TRUTH
                s_valid_attractors = valid_estimator(s_embed)

            separator = hparams.get_separator(hparams.SEPARATOR_TYPE)(
                self, 'separator')
            # per-source magnitude estimates from mixture power + attractors
            s_separated_signals_pwr = separator(s_mixed_signals_power,
                                                s_attractors, s_embed_flat)

            if using_same_method:
                s_separated_signals_pwr_valid = s_separated_signals_pwr
            else:
                s_separated_signals_pwr_valid = separator(
                    s_mixed_signals_power, s_valid_attractors, s_embed_flat)

            # use mixture phase and estimated power to get separated signal
            s_mixed_signals_phase = tf.expand_dims(s_mixed_signals_phase, 1)
            s_separated_signals = tf.complex(
                tf.cos(s_mixed_signals_phase) * s_separated_signals_pwr,
                tf.sin(s_mixed_signals_phase) * s_separated_signals_pwr)

            # loss and SNR for training
            # s_train_loss, v_perms, s_perm_sets = ops.pit_mse_loss(
            # s_src_signals_pwr, s_separated_signals_pwr)
            # NOTE: PIT loss is computed on the complex signals here; the
            # commented-out variant above used magnitudes instead
            s_train_loss, v_perms, s_perm_sets = ops.pit_mse_loss(
                s_src_signals, s_separated_signals)

            # resolve permutation: build [batch_idx, permuted_signal_idx]
            # index pairs so gather_nd reorders the estimated sources to
            # best match the references
            s_perm_idxs = tf.stack([
                tf.tile(tf.expand_dims(tf.range(hparams.BATCH_SIZE), 1),
                        [1, hparams.MAX_N_SIGNAL]),
                tf.gather(v_perms, s_perm_sets)
            ],
                                   axis=2)
            s_perm_idxs = tf.reshape(
                s_perm_idxs, [hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL, 2])
            s_separated_signals = tf.gather_nd(s_separated_signals,
                                               s_perm_idxs)
            s_separated_signals = tf.reshape(s_separated_signals, [
                hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
                hparams.FEATURE_SIZE
            ])

            s_train_snr = tf.reduce_mean(
                ops.batch_snr(s_src_signals, s_separated_signals))

            # ^ for validation / inference
            # same PIT + permutation-resolution dance, but on the
            # inference-path magnitudes (truth-free estimator)
            s_valid_loss, v_perms, s_perm_sets = ops.pit_mse_loss(
                s_src_signals_pwr, s_separated_signals_pwr_valid)
            s_perm_idxs = tf.stack([
                tf.tile(tf.expand_dims(tf.range(hparams.BATCH_SIZE), 1),
                        [1, hparams.MAX_N_SIGNAL]),
                tf.gather(v_perms, s_perm_sets)
            ],
                                   axis=2)
            s_perm_idxs = tf.reshape(
                s_perm_idxs, [hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL, 2])
            s_separated_signals_pwr_valid_pit = tf.gather_nd(
                s_separated_signals_pwr_valid, s_perm_idxs)
            s_separated_signals_pwr_valid_pit = tf.reshape(
                s_separated_signals_pwr_valid_pit, [
                    hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
                    hparams.FEATURE_SIZE
                ])

            # permutation-aligned signals for the validation SNR metric
            s_separated_signals_valid = tf.complex(
                tf.cos(s_mixed_signals_phase) *
                s_separated_signals_pwr_valid_pit,
                tf.sin(s_mixed_signals_phase) *
                s_separated_signals_pwr_valid_pit)
            # inference output keeps the estimator's source order
            # (no ground truth available to resolve permutation)
            s_separated_signals_infer = tf.complex(
                tf.cos(s_mixed_signals_phase) * s_separated_signals_pwr_valid,
                tf.sin(s_mixed_signals_phase) * s_separated_signals_pwr_valid)
            s_valid_snr = tf.reduce_mean(
                ops.batch_snr(s_src_signals, s_separated_signals_valid))

        # ===============
        # prepare summary
        # TODO add impl & summary for word error rate
        with tf.name_scope('train_summary'):
            s_loss_summary_t = tf.summary.scalar('loss', s_train_loss)
            s_snr_summary_t = tf.summary.scalar('SNR', s_train_snr)

        with tf.name_scope('valid_summary'):
            s_loss_summary_v = tf.summary.scalar('loss', s_valid_loss)
            s_snr_summary_v = tf.summary.scalar('SNR', s_valid_snr)

        # apply optimizer
        ozer = hparams.get_optimizer()(learn_rate=self.v_learn_rate,
                                       lr_decay=hparams.LR_DECAY)

        v_params_li = tf.trainable_variables()
        r_apply_grads = ozer.compute_gradients(s_train_loss, v_params_li)
        # optional element-wise gradient clipping; variables with a None
        # gradient are silently dropped from the update
        if hparams.GRAD_CLIP_THRES is not None:
            r_apply_grads = [(tf.clip_by_value(g, -hparams.GRAD_CLIP_THRES,
                                               hparams.GRAD_CLIP_THRES), v)
                             for g, v in r_apply_grads if g is not None]
        self.op_sgd_step = ozer.apply_gradients(r_apply_grads)

        self.op_init_params = tf.variables_initializer(v_params_li)
        self.op_init_states = tf.variables_initializer(
            list(self.s_states_di.values()))

        self.train_feed_keys = [s_src_signals, s_dropout_keep]
        train_summary = tf.summary.merge([s_loss_summary_t, s_snr_summary_t])
        self.train_fetches = [
            train_summary,
            dict(loss=s_train_loss, SNR=s_train_snr), self.op_sgd_step
        ]

        self.valid_feed_keys = self.train_feed_keys
        valid_summary = tf.summary.merge([s_loss_summary_v, s_snr_summary_v])
        self.valid_fetches = [
            valid_summary,
            dict(loss=s_valid_loss, SNR=s_valid_snr)
        ]

        # NOTE(review): inference feeds the derived tensor s_mixed_signals
        # directly (TF allows feeding non-placeholder tensors), bypassing
        # s_src_signals — confirm this is intentional
        self.infer_feed_keys = [s_mixed_signals, s_dropout_keep]
        self.infer_fetches = dict(signals=s_separated_signals_infer)

        if hparams.DEBUG:
            self.debug_feed_keys = [s_src_signals, s_dropout_keep]
            self.debug_fetches = dict(embed=s_embed,
                                      attrs=s_attractors,
                                      input=s_src_signals,
                                      output=s_separated_signals)
            self.debug_fetches.update(encoder.debug_fetches)
            self.debug_fetches.update(separator.debug_fetches)
            # NOTE(review): estimator is always constructed above, so this
            # guard looks vestigial
            if estimator is not None:
                self.debug_fetches.update(estimator.debug_fetches)

        self.saver = tf.train.Saver(var_list=v_params_li)
Example #2
0
    def build(self) -> None:
        """Build a deep-clustering style separation graph.

        Trains embeddings with the affinity loss ||V V^T - Y Y^T||_F^2
        (expanded below into three dot-product terms), then separates at
        validation/inference time by spherical k-means clustering of the
        embeddings followed by binary masking of the mixture magnitude.

        Side effects (attributes set on self): op_sgd_step, op_init_params,
        op_init_states, train/valid/infer/debug feed_keys and fetches, saver.
        """
        # create sub-modules
        encoder = hparams.get_encoder()(self, 'encoder')
        # ===================
        # build the model

        # complex spectrogram input:
        # [batch, max_n_signal, time (dynamic), feature]
        input_shape = [
            hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, None,
            hparams.FEATURE_SIZE
        ]

        s_src_signals = tf.placeholder(hparams.COMPLEXX,
                                       input_shape,
                                       name='source_signal')
        s_dropout_keep = tf.placeholder(hparams.FLOATX, [],
                                        name='dropout_keep')
        reger = hparams.get_regularizer()
        with tf.variable_scope('global', regularizer=reger):
            # TODO add mixing coeff ?

            # get mixed signal by summing the individual sources over the
            # signal axis -> complex [B, T, F]
            s_mixed_signals = tf.reduce_sum(s_src_signals, axis=1)

            s_src_signals_pwr = tf.abs(s_src_signals)
            # mixture phase, reused later to rebuild complex spectrograms
            s_mixed_signals_phase = tf.atan2(tf.imag(s_mixed_signals),
                                             tf.real(s_mixed_signals))
            s_mixed_signals_power = tf.abs(s_mixed_signals)
            # log(1 + |X|) magnitude compression fed to the encoder
            s_mixed_signals_log = tf.log1p(s_mixed_signals_power)
            # int[B, T, F]
            # float[B, T, F, E]
            s_embed = encoder(s_mixed_signals_log)
            # flatten time x frequency: [B, T*F, E]
            s_embed_flat = tf.reshape(
                s_embed, [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])
            # L2-normalize each embedding vector (EPS guards divide-by-zero)
            s_embed_normalized = s_embed_flat * tf.rsqrt(
                tf.reduce_sum(tf.square(s_embed_flat), axis=-1, keep_dims=True)
                + hparams.EPS)

            s_src_signals_pwr_flat = tf.reshape(
                s_src_signals_pwr,
                [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1])
            # NOTE(review): target assignment hard-codes exactly two sources
            # (indices 0 and 1) — assumes MAX_N_SIGNAL == 2; confirm
            s_pwr_diff = s_src_signals_pwr_flat[:,
                                                0] - s_src_signals_pwr_flat[:,
                                                                            1]
            assign_thres = hparams.TRAIN_ASSIGN_THRES
            # one-hot dominant-source labels Y per T-F bin; bins whose power
            # difference is within +/- assign_thres get an all-zero row and
            # thus do not contribute a target assignment
            s_target_category = tf.cast(
                tf.stack([(s_pwr_diff > assign_thres),
                          (s_pwr_diff < (-assign_thres))],
                         axis=2), hparams.FLOATX)

            def mse_dot_loss(s_u_, s_v_):
                '''
                Returns sum(square(dot(s_u_^T, s_v_))), batched
                '''
                return tf.reduce_sum(tf.square(
                    tf.matmul(tf.transpose(s_u_, [0, 2, 1]), s_v_)),
                                     axis=(1, 2))

            # equation (3)
            # s_embed_normalized -> V
            # s_target_category -> Y
            # expansion of ||V V^T - Y Y^T||_F^2 =
            #   ||V^T V||^2 + ||Y^T Y||^2 - 2 ||V^T Y||^2
            # (presumably equation (3) of the deep clustering paper — verify)
            s_train_loss = mse_dot_loss(
                s_embed_normalized, s_embed_normalized) + mse_dot_loss(
                    s_target_category, s_target_category) - 2 * mse_dot_loss(
                        s_embed_normalized, s_target_category)
            s_train_loss = tf.reduce_mean(s_train_loss)

            # random unit-norm initial cluster centers for k-means
            s_init_centers = tf.random_uniform(
                [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, hparams.EMBED_SIZE],
                -1., 1.)
            s_init_centers *= tf.rsqrt(
                tf.reduce_sum(
                    tf.square(s_init_centers), axis=-1, keep_dims=True) +
                hparams.EPS)

            s_centers = ops.kmeans(s_embed_flat,
                                   s_init_centers,
                                   fn_step=ops.spherical_kmeans_step,
                                   max_step=hparams.MAX_KMEANS_ITERS)

            # estimate masks, and separated signals
            # assign each T-F bin to its nearest center (by dot product),
            # then build hard binary masks [B, N, T, F]
            s_cosines = tf.matmul(s_embed_flat,
                                  tf.transpose(s_centers, [0, 2, 1]))
            s_assigns = tf.argmax(s_cosines, axis=2)
            s_masks = tf.one_hot(s_assigns, hparams.MAX_N_SIGNAL)
            s_masks = tf.transpose(s_masks, [0, 2, 1])
            s_masks = tf.reshape(s_masks, [
                hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
                hparams.FEATURE_SIZE
            ])

            # masked mixture magnitudes per estimated source
            s_separated_signals_pwr_valid = tf.expand_dims(
                s_mixed_signals_power, 1) * s_masks
            # PIT here is used only to find the best source permutation;
            # the loss value itself is discarded
            _, v_perms, s_perm_sets = ops.pit_mse_loss(
                s_src_signals_pwr, s_separated_signals_pwr_valid)

            # combine estimated magnitudes with the mixture phase
            s_mixed_signals_phase = tf.expand_dims(s_mixed_signals_phase, 1)
            s_separated_signals_valid = tf.complex(
                tf.cos(s_mixed_signals_phase) * s_separated_signals_pwr_valid,
                tf.sin(s_mixed_signals_phase) * s_separated_signals_pwr_valid)

            # resolve permutation: [batch_idx, permuted_signal_idx] pairs so
            # gather_nd reorders estimates to best match the references
            s_perm_idxs = tf.stack([
                tf.tile(tf.expand_dims(tf.range(hparams.BATCH_SIZE), 1),
                        [1, hparams.MAX_N_SIGNAL]),
                tf.gather(v_perms, s_perm_sets)
            ],
                                   axis=2)
            s_perm_idxs = tf.reshape(
                s_perm_idxs, [hparams.BATCH_SIZE * hparams.MAX_N_SIGNAL, 2])
            s_separated_signals_valid = tf.gather_nd(s_separated_signals_valid,
                                                     s_perm_idxs)
            s_separated_signals_valid = tf.reshape(s_separated_signals_valid, [
                hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, -1,
                hparams.FEATURE_SIZE
            ])

            s_valid_snr = tf.reduce_mean(
                ops.batch_snr(s_src_signals, s_separated_signals_valid))

        # ===============
        # prepare summary

        with tf.name_scope('train_summary'):
            s_loss_summary_t = tf.summary.scalar('loss', s_train_loss)
            # NOTE(review): the training summary reports the k-means based
            # validation SNR — there is no separate training SNR here
            s_snr_summary_t = tf.summary.scalar('SNR', s_valid_snr)

        with tf.name_scope('valid_summary'):
            s_snr_summary_v = tf.summary.scalar('SNR', s_valid_snr)

        # apply optimizer
        ozer = hparams.get_optimizer()(learn_rate=self.v_learn_rate,
                                       lr_decay=hparams.LR_DECAY)

        v_params_li = tf.trainable_variables()
        r_apply_grads = ozer.compute_gradients(s_train_loss, v_params_li)
        # optional element-wise gradient clipping; variables with a None
        # gradient are silently dropped from the update
        if hparams.GRAD_CLIP_THRES is not None:
            r_apply_grads = [(tf.clip_by_value(g, -hparams.GRAD_CLIP_THRES,
                                               hparams.GRAD_CLIP_THRES), v)
                             for g, v in r_apply_grads if g is not None]
        self.op_sgd_step = ozer.apply_gradients(r_apply_grads)

        self.op_init_params = tf.variables_initializer(v_params_li)
        self.op_init_states = tf.variables_initializer(
            list(self.s_states_di.values()))

        self.train_feed_keys = [s_src_signals, s_dropout_keep]
        train_summary = tf.summary.merge([s_loss_summary_t, s_snr_summary_t])
        self.train_fetches = [
            train_summary,
            dict(loss=s_train_loss, SNR=s_valid_snr), self.op_sgd_step
        ]

        self.valid_feed_keys = self.train_feed_keys
        valid_summary = tf.summary.merge([s_snr_summary_v])
        self.valid_fetches = [valid_summary, dict(SNR=s_valid_snr)]

        # NOTE(review): inference feeds the derived tensor s_mixed_signals
        # directly (TF allows feeding non-placeholder tensors), bypassing
        # s_src_signals — confirm this is intentional
        self.infer_feed_keys = [s_mixed_signals, s_dropout_keep]
        self.infer_fetches = dict(signals=s_separated_signals_valid)

        if hparams.DEBUG:
            self.debug_feed_keys = [s_src_signals, s_dropout_keep]
            self.debug_fetches = dict(embed=s_embed,
                                      centers=s_centers,
                                      input=s_src_signals,
                                      output=s_separated_signals_valid)
            self.debug_fetches.update(encoder.debug_fetches)

        self.saver = tf.train.Saver(var_list=v_params_li)