Exemplo n.º 1
0
        def virtual_adversarial_loss():
            """Builds the virtual adversarial loss for the bidirectional model.

            Constructs the language model graph first (without its loss) when
            `self.lm_inputs` has not been set yet, then passes the forward and
            reverse LM embeddings through `self.cl_loss_from_embedding` to
            obtain the logits handed to
            `adv_lib.virtual_adversarial_loss_bidir`.

            The returned tensor carries control dependencies on the
            `save_state` ops of the LM inputs, so evaluating the loss also
            saves the next states (needed for LSTM state-saving BPTT).

            Returns:
              loss: float scalar.
            """
            # Lazily build the LM graph; only its inputs are needed here.
            if self.lm_inputs is None:
                self.language_model_graph(compute_loss=False)

            def embedding_to_logits(embedded, return_next_state=False):
                # Callback given to adv_lib so it can recompute logits from
                # other embeddings; the next states are only returned when
                # explicitly requested.
                intermediates = self.cl_loss_from_embedding(
                    embedded,
                    inputs=self.lm_inputs,
                    return_intermediates=True)
                _, states, logits, _ = intermediates
                return (states, logits) if return_next_state else logits

            # (forward, reverse) embeddings, in that order.
            embedded_pair = tuple(
                self.tensors[key]
                for key in ('lm_embedded', 'lm_embedded_reverse'))
            new_states, clean_logits = embedding_to_logits(
                embedded_pair, return_next_state=True)

            loss = adv_lib.virtual_adversarial_loss_bidir(
                clean_logits,
                embedded_pair,
                self.lm_inputs,
                embedding_to_logits)

            # One save op per (input, next state) pair.
            save_ops = []
            for lm_input, new_state in zip(self.lm_inputs, new_states):
                save_ops.append(lm_input.save_state(new_state))

            # Tie the state saves to the loss tensor that is returned.
            with tf.control_dependencies(save_ops):
                return tf.identity(loss)
Exemplo n.º 2
0
    def virtual_adversarial_loss():
      """Builds the virtual adversarial loss for the bidirectional model.

      Constructs the language model graph first (without its loss) when
      `self.lm_inputs` has not been set yet, then passes the forward and
      reverse LM embeddings through `self.cl_loss_from_embedding` to obtain
      the logits handed to `adv_lib.virtual_adversarial_loss_bidir`.

      The returned tensor carries control dependencies on the `save_state`
      ops of the LM inputs, so evaluating the loss also saves the next states
      (needed for LSTM state-saving BPTT).

      Returns:
        loss: float scalar.
      """
      # Lazily build the LM graph; only its inputs are needed here.
      if self.lm_inputs is None:
        self.language_model_graph(compute_loss=False)

      def embedding_to_logits(embedded, return_next_state=False):
        # Callback given to adv_lib so it can recompute logits from other
        # embeddings; next states are only returned when requested.
        intermediates = self.cl_loss_from_embedding(
            embedded, inputs=self.lm_inputs, return_intermediates=True)
        _, states, logits, _ = intermediates
        if not return_next_state:
          return logits
        return states, logits

      # (forward, reverse) embeddings, in that order.
      forward_embedded = self.tensors['lm_embedded']
      reverse_embedded = self.tensors['lm_embedded_reverse']
      embedded_pair = (forward_embedded, reverse_embedded)

      new_states, clean_logits = embedding_to_logits(
          embedded_pair, return_next_state=True)

      loss = adv_lib.virtual_adversarial_loss_bidir(
          clean_logits, embedded_pair, self.lm_inputs, embedding_to_logits)

      # One save op per (input, next state) pair.
      save_ops = []
      for lm_input, new_state in zip(self.lm_inputs, new_states):
        save_ops.append(lm_input.save_state(new_state))

      # Tie the state saves to the loss tensor that is returned.
      with tf.control_dependencies(save_ops):
        return tf.identity(loss)