Пример #1
0
    def __init__(self, models, configs, beam_size, beta):
        """Sets some things up then calls _beam_search() to do the real work.

        Args:
            models: a sequence of RNN or Transformer objects.
            configs: a sequence of model configs (argparse.Namespace objects),
                one per model and in the same order as `models`.
            beam_size: an integer specifying the beam width.
            beta: a float between 0.0 and 1.0, passed through unchanged to
                _beam_search() as its `beta` argument. NOTE(review): it is
                used there together with `last_translation`; presumably it
                weights a bias towards that previous translation — confirm.

        Raises:
            exception.Error: if `models` and `configs` differ in length, if
                `models` is empty, if a config has a `model_type` other than
                'transformer' or 'rnn', or if the models have different
                target vocabulary sizes.
        """
        # zip() below would silently drop any unmatched model or config, so
        # insist on a one-to-one pairing before building anything.
        if len(models) != len(configs):
            raise exception.Error(
                'Number of models ({}) does not match number of configs '
                '({})'.format(len(models), len(configs)))
        # Without this, an empty ensemble fails later with a bare IndexError
        # on vocab_sizes[0].
        if not models:
            raise exception.Error('Cannot build a beam search sampler '
                                  'without any models')

        self._models = models
        self._configs = configs
        self._beam_size = beam_size
        self.beta = beta

        with tf.name_scope('beam_search'):

            # Define placeholders.
            self.inputs = sampler_inputs.SamplerInputs()

            # Create model adapters to get a consistent interface to
            # Transformer and RNN models.
            model_adapters = []
            for i, (model, config) in enumerate(zip(models, configs)):
                with tf.name_scope('model_adapter_{}'.format(i)) as scope:
                    if config.model_type == 'transformer':
                        adapter = transformer_inference.ModelAdapter(
                            model, config, scope)
                    elif config.model_type == 'rnn':
                        adapter = rnn_inference.ModelAdapter(
                            model, config, scope)
                    else:
                        # Explicit error rather than `assert`: asserts are
                        # stripped under `python -O`, which would silently
                        # treat any unknown model type as an RNN.
                        raise exception.Error(
                            'Unsupported model_type: {!r}'.format(
                                config.model_type))
                    model_adapters.append(adapter)

            # Check that individual models are compatible with each other.
            vocab_sizes = [a.target_vocab_size for a in model_adapters]
            if len(set(vocab_sizes)) > 1:
                raise exception.Error('Cannot ensemble models with different '
                                      'target vocabulary sizes')
            target_vocab_size = vocab_sizes[0]

            # Build the graph to do the actual work.
            sequences, scores = _beam_search(
                model_adapters=model_adapters,
                beam_size=beam_size,
                batch_size_x=self.inputs.batch_size_x,
                max_translation_len=self.inputs.max_translation_len,
                normalization_alpha=self.inputs.normalization_alpha,
                vocab_size=target_vocab_size,
                eos_id=0,
                last_translation=self.inputs.last_translation,
                last_translation_len=self.inputs.last_translation_len,
                beta=self.beta)

            self._outputs = sequences, scores
Пример #2
0
    def __init__(self, models, configs, beam_size):
        """Sets some things up then calls _random_sample() to do the real work.

        Args:
            models: a sequence of RNN or Transformer objects.
            configs: a sequence of model configs (argparse.Namespace objects),
                one per model and in the same order as `models`.
            beam_size: an integer passed through unchanged to
                _random_sample() as its `beam_size` argument.
                NOTE(review): for a random sampler this is presumably the
                number of samples drawn per input sentence rather than a
                beam width — confirm against _random_sample().

        Raises:
            ValueError: if `models` and `configs` differ in length, or if a
                config has a `model_type` other than 'transformer' or 'rnn'.
        """
        # zip() below would silently drop any unmatched model or config, so
        # insist on a one-to-one pairing before building anything.
        if len(models) != len(configs):
            raise ValueError(
                'Number of models ({}) does not match number of configs '
                '({})'.format(len(models), len(configs)))

        self._models = models
        self._configs = configs
        self._beam_size = beam_size

        with tf.compat.v1.name_scope('random_sampler'):

            # Define placeholders.
            self.inputs = sampler_inputs.SamplerInputs()

            # Create an adapter to get a consistent interface to
            # Transformer and RNN models.
            model_adapters = []
            for i, (model, config) in enumerate(zip(models, configs)):
                with tf.compat.v1.name_scope(
                        'model_adapter_{}'.format(i)) as scope:
                    if config.model_type == 'transformer':
                        adapter = transformer_inference.ModelAdapter(
                            model, config, scope)
                    elif config.model_type == 'rnn':
                        adapter = rnn_inference.ModelAdapter(
                            model, config, scope)
                    else:
                        # Explicit error rather than `assert`: asserts are
                        # stripped under `python -O`, which would silently
                        # treat any unknown model type as an RNN.
                        raise ValueError(
                            'Unsupported model_type: {!r}'.format(
                                config.model_type))
                    model_adapters.append(adapter)

            # Build the graph to do the actual work.
            sequences, scores = _random_sample(
                model_adapters=model_adapters,
                beam_size=beam_size,
                batch_size_x=self.inputs.batch_size_x,
                max_translation_len=self.inputs.max_translation_len,
                normalization_alpha=self.inputs.normalization_alpha,
                eos_id=0)

            self._outputs = sequences, scores