Example #1
    def pooling_layer(self, x, pooling_type=None):
        '''
        Add a pooling layer across the whole utterance.
        Input: [B, T, D]
          --> Reduce along T

        Statistics pooling output: [B, D * 2]
        Average pooling output: [B, D]
        '''
        assert_rank3 = tf.debugging.assert_rank(x, 3)
        with tf.control_dependencies([assert_rank3]):
            x = tf.identity(x)

        pooling_type = pooling_type if pooling_type else self.netconf[
            'frame_pooling_type']
        if pooling_type == 'stats':
            with tf.name_scope('stats_pooling'):
                mean, var = tf.nn.moments(x, 1)
                x = tf.concat([mean, tf.sqrt(var + 1e-6)], 1)
        elif pooling_type == 'average':
            with tf.name_scope('average_pooling'):
                mean, _ = tf.nn.moments(x, 1)
                x = mean
        else:
            raise ValueError('Unsupported frame_pooling_type: %s' %
                             (pooling_type))

        assert_rank2 = tf.debugging.assert_rank(x, 2)
        with tf.control_dependencies([assert_rank2]):
            x = tf.identity(x)

        return x
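
For reference, here is a minimal standalone sketch of the two pooling modes (shapes only; the tensor sizes are made up, and the calls mirror the ones above):

import tensorflow as tf

# Illustrative shapes: stats pooling doubles the feature dimension, average pooling keeps it.
x = tf.random.normal([4, 100, 64])                  # [B, T, D]
mean, var = tf.nn.moments(x, 1)                     # reduce along T -> [B, D]
stats = tf.concat([mean, tf.sqrt(var + 1e-6)], 1)   # [B, D * 2] = [4, 128]
avg = mean                                          # [B, D] = [4, 64]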
Example #2
    def call(self, audio_data, sample_rate=None):
        """
        Calculate fbank features of audio data.
        :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
        :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
        :return: A float tensor of size (num_channels, num_frames, num_frequencies) containing
            fbank features of every frame in speech.
        """
        p = self.config
        with tf.name_scope('fbank'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            if p.upper_frequency_limit <= 0:
                p.upper_frequency_limit = p.sample_rate / 2.0 + p.upper_frequency_limit
            elif (p.upper_frequency_limit <= p.lower_frequency_limit) or (
                    p.upper_frequency_limit > p.sample_rate / 2.0):
                p.upper_frequency_limit = p.sample_rate / 2.0

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                spectrum = self.spect(audio_data, sample_rate)
                spectrum = tf.expand_dims(spectrum, 0)

                fbank = py_x_ops.fbank(
                    spectrum,
                    sample_rate,
                    upper_frequency_limit=p.upper_frequency_limit,
                    lower_frequency_limit=p.lower_frequency_limit,
                    filterbank_channel_count=p.filterbank_channel_count)

                return fbank
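
A possible usage sketch of this method; `fbank_layer` stands in for an instance of the class that owns this `call` (the instance name and the Keras-style __call__ -> call convention are assumptions, not taken from the source):

import tensorflow as tf

audio_data = tf.random.normal([1, 16000])                 # a (1, N) waveform, e.g. 1 s at 16 kHz
fbank_feats = fbank_layer(audio_data, sample_rate=16000)
# fbank_feats: (num_channels, num_frames, num_frequencies), per the docstring above

Most of the later front-end examples (cepstrum, MFCC, PLP, spectrum, pitch) follow the same calling convention.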
Example #3
  def call(self, audio_data, sample_rate=None):
    """
    Calculate cepstrum of audio data.
    :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
    :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
    :return: A float tensor of size (num_frames, ceps_subband_num) containing normalized cepstrum
        (tag_ceps_mean_norm = True) or cepstrum (tag_ceps_mean_norm = False) of every frame in speech.
    """

    p = self.config

    with tf.name_scope('cepstrum'):

      if sample_rate is None:
        sample_rate = tf.constant(p.sample_rate, dtype=float)

      assert_op = tf.assert_equal(
          tf.constant(p.sample_rate), tf.cast(sample_rate, dtype=float))
      with tf.control_dependencies([assert_op]):

        cepstrum = py_x_ops.cepstrum(
            audio_data,
            sample_rate,
            window_length=p.window_length,
            frame_length=p.frame_length,
            ceps_subband_num=p.ceps_subband_num,
            tag_ceps_mean_norm=p.tag_ceps_mean_norm)

        return cepstrum
Example #4
    def call(self, audio_data, sample_rate=None):
        """
        Calculate MFCC features of audio data.
        :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
        :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
        :return: A float tensor of size (num_channels, num_frames, num_frequencies) containing
            MFCC features of every frame in speech.
        """
        p = self.config
        with tf.name_scope('mfcc'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                spectrum_feats = self.spect(audio_data, sample_rate)
                spectrum_feats = tf.expand_dims(spectrum_feats, 0)
                fbank_feats = self.fbank(audio_data, sample_rate)
                mfcc = py_x_ops.mfcc(fbank_feats,
                                     spectrum_feats,
                                     sample_rate,
                                     use_energy=p.use_energy,
                                     cepstral_lifter=p.cepstral_lifter,
                                     coefficient_count=p.coefficient_count)
                return mfcc
Example #5
    def call(self, audio_data, sample_rate=None):
        """
        Calculate the power spectrum and phase spectrum of audio data.
        :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
        :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
        :return: Two tensors:
            power spectrum: a float tensor of size (num_frames, num_frequencies) containing
                the power spectrum of every frame in speech.
            phase spectrum: a float tensor of size (num_frames, num_frequencies) containing
                the phase spectrum of every frame in speech.
        """

        p = self.config
        with tf.name_scope('analyfiltbank'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                sample_rate = tf.cast(sample_rate, dtype=float)
                power_spectrum, phase_spectrum = py_x_ops.analyfiltbank(
                    audio_data,
                    sample_rate,
                    window_length=p.window_length,
                    frame_length=p.frame_length)

                return power_spectrum, phase_spectrum
Example #6
  def call(self, audio_data, sample_rate=None):
    """
    Add RIR reverberation, noise and/or AEC residual effects to audio data.
    :param audio_data: the audio signal to be augmented. Should be a (1, N) tensor.
    :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
    :return: A float tensor of size N containing the noise-added audio.
    """

    p = self.config
    with tf.name_scope('add_rir_noise_aecres'):
      if sample_rate is None:
        sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

      assert_op = tf.assert_equal(
          tf.constant(p.sample_rate), tf.cast(sample_rate, dtype=tf.int32))
      with tf.control_dependencies([assert_op]):
        sample_rate = tf.cast(sample_rate, dtype=float)
        add_rir_noise_aecres_out = py_x_ops.add_rir_noise_aecres(
            audio_data,
            sample_rate,
            if_add_rir=p.if_add_rir,
            rir_filelist=p.rir_filelist,
            if_add_noise=p.if_add_noise,
            snr_min=p.snr_min,
            snr_max=p.snr_max,
            noise_filelist=p.noise_filelist,
            if_add_aecres=p.if_add_aecres,
            aecres_filelist=p.aecres_filelist)

        return tf.squeeze(add_rir_noise_aecres_out)
Example #7
    def call(self, audio_data, sample_rate=None):
        """
        Calculate concatenated fbank and pitch features of a wav.
        :param audio_data: the audio signal from which to compute the spectrum.
                           Should be a (1, N) tensor.
        :param sample_rate: the sample rate of the signal we are working with.
        :return: A tensor with shape (num_frames, dim_features), containing
            fbank and pitch features of every frame in speech.
        """

        p = self.config
        with tf.name_scope('fbank_pitch'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                fbank_feats = tf.squeeze(self.fbank(audio_data, sample_rate))
                pitch_feats = tf.squeeze(self.pitch(audio_data, sample_rate))
                fbank_pitch_feats = tf.concat([fbank_feats, pitch_feats], 1)

                return fbank_pitch_feats
Example #8
    def call(self, audio_data, sample_rate=None):
        """
        Calculate PLP features of audio data.
        :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
        :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
        :return: A float tensor of size (num_frames, (plp_order + 1)) containing PLP features of every frame in speech.
        """

        p = self.config
        with tf.name_scope('plp'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                sample_rate = tf.cast(sample_rate, dtype=float)
                plp = py_x_ops.plp(audio_data,
                                   sample_rate,
                                   window_length=p.window_length,
                                   frame_length=p.frame_length,
                                   plp_order=p.plp_order)
                return plp
Example #9
  def train(self):  # pylint: disable=too-many-locals
    """Train the model."""
    mode = utils.TRAIN
    train_model = self.build(mode)

    # Supervisor
    with tf.name_scope("train"):
      global_step = tf.train.get_or_create_global_step()
      train_op = self.get_train_op(train_model.loss_op, global_step)

      checkpoint_dir = get_checkpoint_dir(self.config)

      # scaffold
      scaffold = self.get_scaffold(mode, global_step,
                                   train_model.iterator.initializer)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_dir,
        scaffold=scaffold,
        save_checkpoint_steps=self.save_checkpoint_steps,
        config=self.session_conf) as sess:
      # Training loop. For each batch...
      data_size = self.config['data']['train_data_size']
      num_epochs = self.config["data"]["task"]['epochs']
      num_batch = int(math.ceil(data_size * num_epochs / self.batch_size))
      num_batch_per_epoch = int(data_size / self.batch_size)
      logging.info(
          "num_batch: {}, num_batch_per_epoch: {}, num_epochs: {}".format(
              num_batch, num_batch_per_epoch, num_epochs))
      for i in range(num_batch):
        _, _, out_loss = sess.run([train_op, global_step, train_model.loss_op])
        if i % self.print_every == 0 or i == num_batch - 1:
          logging.info("Training for epoch {}: [ {:.2%} ] loss is {:g}".format(
              int(i / num_batch_per_epoch),
              (i % num_batch_per_epoch) / num_batch_per_epoch, out_loss))
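
To make the progress arithmetic in this loop concrete, here is a small worked example with made-up sizes:

import math

data_size, num_epochs, batch_size = 1000, 3, 32
num_batch = int(math.ceil(data_size * num_epochs / batch_size))   # 94 total steps
num_batch_per_epoch = int(data_size / batch_size)                 # 31 steps per epoch
epoch = int(40 / num_batch_per_epoch)                             # step 40 -> epoch 1
progress = (40 % num_batch_per_epoch) / num_batch_per_epoch       # ~29% through the epoch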
Example #10
 def get_pre_train_graph(self, inputs):
     pretrained_model_meta = self.pretrained_model_path + '.meta'
     seg_idx = self.pretrained_model_seg
     pad_idx = self.pretrained_model_pad
     with tf.name_scope('pretrain_graph') as scope:
         pretrained_graph = tf.get_default_graph()
         if self.pretrained_model_name == 'elmo':
             pretrained_saver = tf.train.import_meta_graph(
                 pretrained_model_meta, input_map={'input_x:0': inputs})
             input_x_pretrained = pretrained_graph.get_tensor_by_name(
                 scope + 'input_x_elmo:0')
         if self.pretrained_model_name == 'bert':
             pad_mask = get_pad_mask_from_token_idx(inputs, pad_idx)
             segment_mask = get_seg_mask_from_token_idx(inputs, seg_idx)
             pretrained_saver = tf.train.import_meta_graph(
                 pretrained_model_meta,
                 input_map={
                     'input_ids:0': inputs,
                     'input_mask:0': pad_mask,
                     'segment_ids:0': segment_mask
                 })
             if self.pretrained_model_output == 'seq':
                 input_x_pretrained = \
                   pretrained_graph.get_tensor_by_name(scope + 'encoder_layers_{}:0'.
                                                       format(self.pretrained_model_layers))
             else:
                 input_x_pretrained = pretrained_graph.get_tensor_by_name(
                     scope + 'input_x_bert_cls:0')
         return input_x_pretrained
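
The key mechanism here is tf.train.import_meta_graph with an input_map, which splices the pretrained graph's input placeholder onto a tensor from the current graph. A minimal sketch in TensorFlow 1.x graph mode (the checkpoint path is a placeholder and the tensor names are only illustrative):

import tensorflow as tf

inputs = tf.placeholder(tf.int32, [None, 128], name='new_input_x')
saver = tf.train.import_meta_graph('/path/to/pretrained.meta',
                                   input_map={'input_x:0': inputs})
pretrained_out = tf.get_default_graph().get_tensor_by_name('input_x_elmo:0')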
Example #11
    def train_and_eval(self):  # pylint: disable=too-many-locals
        """Train and evaluate the model."""
        # train related
        g_train = tf.Graph()
        with g_train.as_default():
            logging.info("Compiling train model ...")
            train_model = self.build(utils.TRAIN)
        # eval related
        g_eval = tf.Graph()
        with g_eval.as_default():
            logging.info("Compiling eval model ...")
            eval_model = self.build(utils.EVAL)
            eval_model.sess = tf.Session(config=self.session_conf,
                                         graph=g_eval)
            eval_model.saver = tf.train.Saver()

        # start train
        with g_train.as_default():
            # Supervisor
            with tf.name_scope("train"):
                global_step = tf.train.get_or_create_global_step()

                train_op = self.get_train_op(train_model.loss_op, global_step)

                checkpoint_dir = get_checkpoint_dir(self.config)

                # scaffold
                scaffold = self.get_scaffold(utils.TRAIN, global_step,
                                             train_model.iterator.initializer)

                with tf.train.MonitoredTrainingSession(
                        checkpoint_dir=checkpoint_dir,
                        scaffold=scaffold,
                        save_checkpoint_steps=self.save_checkpoint_steps,
                        config=self.session_conf) as sess:
                    # Training loop. For each batch...
                    train_data_size = self.config['data']['train_data_size']
                    num_batch = math.ceil(train_data_size * self.num_epochs /
                                          self.batch_size)
                    num_batch_per_epoch = math.ceil(train_data_size /
                                                    self.batch_size)
                    logging.info("Total data size: {}, batch num: {}, "
                                 "batch num per epoch: {}".format(
                                     train_data_size, num_batch,
                                     num_batch_per_epoch))
                    for i in range(0, num_batch):

                        if i % self.save_checkpoint_steps == 0 and i != 0:
                            self.eval_or_infer_core(eval_model, utils.EVAL)
                        _, _, out_loss = sess.run(
                            [train_op, global_step, train_model.loss_op])
                        if i % self.print_every == 0 or i == num_batch - 1 or (
                                i + 1
                        ) % num_batch_per_epoch == 0 or i % num_batch_per_epoch == 0:
                            logging.info(
                                "Training for epoch {}: [ {:.2%} ] loss is {:g}"
                                .format(int(i / num_batch_per_epoch),
                                        (i % num_batch_per_epoch) /
                                        num_batch_per_epoch, out_loss))
        eval_model.sess.close()
Example #12
    def call(self, power_spectrum, phase_spectrum, sample_rate=None):
        """
        Implement frequency-domain to time-domain conversion.
        :param power_spectrum: a float tensor of size (num_frames, num_frequencies).
        :param phase_spectrum: a float tensor of size (num_frames, num_frequencies).
        :param sample_rate: a scalar tensor.
        :return: audio data
        """

        p = self.config
        with tf.name_scope('synthfiltbank'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                audio_data = py_x_ops.synthfiltbank(
                    power_spectrum,
                    phase_spectrum,
                    sample_rate,
                    window_length=p.window_length,
                    frame_length=p.frame_length)

                return audio_data
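
Combined with the analysis filter bank in Example #5, this op allows a simple analysis/synthesis round trip. A hedged sketch, assuming `analyfb` and `synthfb` are instances of the analysis and synthesis classes respectively (the instance names are illustrative):

power_spec, phase_spec = analyfb(audio_data, sample_rate=16000)
reconstructed = synthfb(power_spec, phase_spec, sample_rate=16000)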
Example #13
    def call(self, audio_data, sample_rate=None):
        """
        Calculate the zero-crossing rate of speech.
        :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
        :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
        :return: A tensor with shape (1, num_frames), containing the zero-crossing rate of every frame in speech.
        """

        p = self.config
        with tf.name_scope('zcr'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                sample_rate = tf.cast(sample_rate, dtype=float)
                zcr = py_x_ops.zcr(audio_data,
                                   sample_rate,
                                   window_length=p.window_length,
                                   frame_length=p.frame_length)

                return zcr
Example #14
  def call(self, audio_data, sample_rate=None):
    """
    Calculate the power of every frame in speech.
    :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
    :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
    :return: A float tensor of size (1, num_frames) containing the power of every frame in speech.
    """

    p = self.config
    with tf.name_scope('framepow'):

      if sample_rate is None:
        sample_rate = tf.constant(p.sample_rate, dtype=float)

      assert_op = tf.assert_equal(
          tf.constant(p.sample_rate), tf.cast(sample_rate, dtype=float))
      with tf.control_dependencies([assert_op]):

        framepow = py_x_ops.frame_pow(
            audio_data,
            sample_rate,
            window_length=p.window_length,
            frame_length=p.frame_length)

        return framepow
Example #15
    def call(self, audio_data, sample_rate=None):
        """
        Calculate the power spectrum or log power spectrum of audio data.
        :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
        :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
        :return: A float tensor of size (num_frames, num_frequencies) containing the power spectrum (output_type=1)
            or log power spectrum (output_type=2) of every frame in speech.
        """

        p = self.config
        with tf.name_scope('spectrum'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=tf.int32))
            with tf.control_dependencies([assert_op]):

                sample_rate = tf.cast(sample_rate, dtype=float)
                spectrum = py_x_ops.spectrum(
                    audio_data,
                    sample_rate,
                    window_length=p.window_length,
                    frame_length=p.frame_length,
                    output_type=p.output_type,
                    snip_edges=p.snip_edges,
                    raw_energy=p.raw_energy,
                    preEph_coeff=p.preeph_coeff,
                    window_type=p.window_type,
                    remove_dc_offset=p.remove_dc_offset,
                    is_fbank=p.is_fbank)

                return spectrum
Example #16
    def call(self, audio_data, sample_rate=None):
        """
        Calculate pitch features of audio data.
        :param audio_data: the audio signal from which to compute the spectrum. Should be a (1, N) tensor.
        :param sample_rate: [optional] the sample rate of the signal we are working with, default is 16kHz.
        :return: A float tensor of size (num_frames, 1) containing pitch features of every frame in speech.
        """

        p = self.config
        with tf.name_scope('pitch'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=float)

            assert_op = tf.assert_equal(tf.constant(p.sample_rate),
                                        tf.cast(sample_rate, dtype=float))
            with tf.control_dependencies([assert_op]):

                pitch = py_x_ops.pitch(audio_data,
                                       sample_rate,
                                       window_length=p.window_length,
                                       frame_length=p.frame_length,
                                       thres_autoc=p.thres_autoc)

                pitch = tf.squeeze(pitch)
                pitch = tf.transpose(pitch[None, :])
                return pitch
Example #17
    def call(self, audio_data, sample_rate=None):
        """
        Calculate pitch features of audio data.
        :param audio_data: the audio signal from which to compute the spectrum.
                           Should be a (1, N) tensor.
        :param sample_rate: the sample rate of the signal we are working with.
        :return: A float tensor of size (num_frames, 2) containing
            pitch and POV features of every frame in speech.
        """
        p = self.config

        with tf.name_scope('pitch'):

            if sample_rate is None:
                sample_rate = tf.constant(p.sample_rate, dtype=tf.int32)
            else:
                if not tf.is_tensor(sample_rate):
                    sample_rate = tf.convert_to_tensor(sample_rate)

            pitch = py_x_ops.pitch(
                audio_data,
                sample_rate,
                window_length=p.window_length,
                frame_length=p.frame_length,
                snip_edges=p.snip_edges,
                preemph_coeff=p.preemph_coeff,
                min_f0=p.min_f0,
                max_f0=p.max_f0,
                soft_min_f0=p.soft_min_f0,
                penalty_factor=p.penalty_factor,
                lowpass_cutoff=p.lowpass_cutoff,
                resample_freq=p.resample_freq,
                delta_pitch=p.delta_pitch,
                nccf_ballast=p.nccf_ballast,
                lowpass_filter_width=p.lowpass_filter_width,
                upsample_filter_width=p.upsample_filter_width,
                max_frames_latency=p.max_frames_latency,
                frames_per_chunk=p.frames_per_chunk,
                simulate_first_pass_online=p.simulate_first_pass_online,
                recompute_frame=p.recompute_frame,
                nccf_ballast_online=p.nccf_ballast_online,
                pitch_scale=p.pitch_scale,
                pov_scale=p.pov_scale,
                pov_offset=p.pov_offset,
                delta_pitch_scale=p.delta_pitch_scale,
                delta_pitch_noise_stddev=p.delta_pitch_noise_stddev,
                normalization_left_context=p.normalization_left_context,
                normalization_right_context=p.normalization_right_context,
                delta_window=p.delta_window,
                delay=p.delay,
                add_pov_feature=p.add_pov_feature,
                add_normalized_log_pitch=p.add_normalized_log_pitch,
                add_delta_pitch=p.add_delta_pitch,
                add_raw_log_pitch=p.add_raw_log_pitch)

            return pitch
Example #18
 def l2_loss(self, tvars=None):
     _l2_loss = 0.0
     weight_decay = self.config['solver']['optimizer'].get(
         'weight_decay', None)
     if weight_decay:
         logging.info(f"add L2 Loss with decay: {weight_decay}")
         with tf.name_scope('l2_loss'):
             tvars = tvars if tvars else tf.trainable_variables()
             tvars = [v for v in tvars if 'bias' not in v.name]
             _l2_loss = weight_decay * tf.add_n(
                 [tf.nn.l2_loss(v) for v in tvars])
             summary_lib.scalar('l2_loss', _l2_loss)
     return _l2_loss
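
A standalone sketch of the same weight-decay term outside the class (TF 1.x graph-mode style; the variable names and decay value are illustrative):

import tensorflow as tf

w = tf.get_variable('dense/kernel', [128, 10])
b = tf.get_variable('dense/bias', [10])
weight_decay = 1e-4
tvars = [v for v in tf.trainable_variables() if 'bias' not in v.name]   # keeps only the kernel
l2_loss = weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in tvars])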
Example #19
    def call(self, in_wavfile, out_wavfile):
        """
        Read a clean wav and write out a noisy wav.
        :param in_wavfile: clean wavfile path.
        :param out_wavfile: noisy wavfile path.
        :return: write wav operation.
        """

        with tf.name_scope('add_noise_end_to_end'):
            input_data, sample_rate = self.read_wav(in_wavfile)
            noisy_data = self.add_noise(input_data, sample_rate) / 32768
            write_op = self.write_wav(out_wavfile, noisy_data, sample_rate)

        return write_op
Example #20
  def call(self, feat, order, window):
    """
    Calculate deltas of features.
    :param feat: a float tensor of size (num_frames, dim_feat).
    :param order: an int.
    :param window: an int.
    :return: A tensor with shape (num_frames, (dim_feat * (order + 1))),
        containing deltas of the features of every frame in speech.
    """

    p = self.config
    with tf.name_scope('delta_delta'):
      delta_delta = py_x_ops.delta_delta(feat, order, window)

    return delta_delta
Example #21
def accuracy(logits, labels):
    ''' accuracy candies
    params:
      logits: [B, ..., D]
      labels: [B, ...]
    return:
      accuracy tensor
    '''
    with tf.name_scope('accuracy'):
        assert_rank = tf.assert_equal(tf.rank(logits), tf.rank(labels) + 1)
        assert_shape = tf.assert_equal(tf.shape(logits)[:-1], tf.shape(labels))
        with tf.control_dependencies([assert_rank, assert_shape]):
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int64)
            labels = tf.cast(labels, tf.int64)
            return tf.reduce_mean(
                tf.cast(tf.equal(predictions, labels), dtype=tf.float32))
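
A toy check of this helper with made-up values (run under a session in graph mode, or directly in eager mode):

import tensorflow as tf

logits = tf.constant([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])   # [B, D] with B=3, D=2
labels = tf.constant([0, 1, 1])                              # [B]
acc = accuracy(logits, labels)                               # evaluates to 2/3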
Example #22
    def call(self, inps, training=None, mask=None):
        if not self.is_infer:
            dec_inp, enc_out = inps
            with tf.name_scope('while'):
                dec_out = self.decode(dec_inp, enc_out, training, mask)
                scores = self.final_dense(dec_out)
                return scores
        else:
            enc_out = inps
            init_ids = tf.cast(
                tf.ones([utils.shape_list(enc_out)[0]]) * self.sos_id,
                tf.int32)
            # Beam Search
            enc_shape = utils.shape_list(enc_out)
            enc_out = tf.tile(tf.expand_dims(enc_out, axis=1),
                              [1, self.beam_size, 1, 1])
            enc_out = tf.reshape(
                enc_out,
                [enc_shape[0] * self.beam_size, enc_shape[1], enc_shape[2]])
            enc_mask = tf.tile(tf.expand_dims(mask, axis=1),
                               [1, self.beam_size, 1, 1, 1])
            enc_mask = tf.reshape(enc_mask,
                                  [enc_shape[0] * self.beam_size, 1, 1, -1])

            def symbols_to_logits_fn(dec_inps):
                dec_out = self.decode(dec_inps, enc_out, training, enc_mask)
                scores = self.final_dense(dec_out)
                return scores[:, -1, :]

            decoded_ids, scores, _ = self.beam_search(symbols_to_logits_fn,
                                                      init_ids, self.beam_size,
                                                      self.max_dec_len,
                                                      self.vocab_size,
                                                      self.length_penalty,
                                                      self.eos_id)
            decoded_ids = decoded_ids[:, 0, 1:]

            return decoded_ids
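
The beam-search branch above tiles the encoder output once per beam before decoding. A small shape sketch of that tiling with made-up sizes:

import tensorflow as tf

B, T, D, beam_size = 2, 50, 256, 4
enc_out = tf.zeros([B, T, D])
enc_out = tf.tile(tf.expand_dims(enc_out, axis=1), [1, beam_size, 1, 1])   # [B, beam, T, D]
enc_out = tf.reshape(enc_out, [B * beam_size, T, D])                       # [B * beam, T, D]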