Example #1
    def _decode(self, tokens, encoder_outs):
        # wrap in Variable
        tokens = utils.volatile_variable(tokens)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            with utils.maybe_no_grad():
                decoder_out, attn = model.decoder(tokens, encoder_out)
            probs = model.get_normalized_probs(decoder_out[:, -1, :],
                                               log_probs=False).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn[:, -1, :].data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))

        return avg_probs, avg_attn
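A plausible reconstruction (an assumption, not the verbatim fairseq source) of the pre-PyTorch-0.4 helpers these examples lean on: utils.volatile_variable wraps a tensor for inference-only use, and utils.maybe_no_grad enters torch.no_grad() where it exists.

import contextlib

import torch
from torch.autograd import Variable

def volatile_variable(*args, **kwargs):
    # `volatile` was removed in torch 0.4; fall back to a plain Variable
    # and rely on the no_grad() context instead.
    if hasattr(torch, 'no_grad'):
        return Variable(*args, **kwargs)
    kwargs['volatile'] = True
    return Variable(*args, **kwargs)

@contextlib.contextmanager
def maybe_no_grad(condition=True):
    # Enter torch.no_grad() when it exists; otherwise do nothing.
    if condition and hasattr(torch, 'no_grad'):
        with torch.no_grad():
            yield
    else:
        yield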
Example #2
    def _decode(self, tokens, encoder_outs, src_doctopic_reshaped, incremental_states):
        # wrap in Variable
        tokens = utils.volatile_variable(tokens)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            with utils.maybe_no_grad():
                decoder_out, attn = model.decoder(
                    tokens, encoder_out, src_doctopic_reshaped,
                    incremental_states[model])
            probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn[:, -1, :].data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))

        return avg_probs, avg_attn
Example #3
    def _decode(self, tokens, encoder_outs, incremental_states):
        # wrap in Variable
        tokens = utils.volatile_variable(tokens)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            with utils.maybe_no_grad():
                if incremental_states[model] is not None:
                    decoder_out = list(
                        model.decoder(tokens, encoder_out,
                                      incremental_states[model]))
                else:
                    decoder_out = list(model.decoder(tokens, encoder_out))
                decoder_out[0] = decoder_out[0][:, -1, :]
                attn = decoder_out[1]
            probs = model.get_normalized_probs(decoder_out,
                                               log_probs=False).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn[:, -1, :].data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))

        return avg_probs, avg_attn
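A tiny numeric sketch (hypothetical tensors, not from the source) of the ensemble rule shared by these _decode variants: per-model probabilities are averaged in probability space and logged once at the end, i.e. log(mean(p)) rather than mean(log(p)).

import torch

p1 = torch.tensor([0.7, 0.2, 0.1])  # model 1 next-token distribution
p2 = torch.tensor([0.5, 0.4, 0.1])  # model 2 next-token distribution

avg = (p1 + p2) / 2                 # mirrors avg_probs.add_ / div_
log_avg = avg.log()                 # mirrors avg_probs.log_()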
Example #4
    def forward(self, input, incremental_state=None):
        """
        Input: Time x Batch x Channel.
        Args:
            incremental_state: Used to buffer signal; if not None, then input is
                expected to contain a single frame. If the input order changes
                between time steps, call reorder_incremental_state.
        """
        if incremental_state is None:
            return super().forward(input)

        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]

        bsz = input.size(0)  # input: bsz x len x dim
        if kw > 1:
            input = input.data
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = input.new(bsz, kw, input.size(2)).zero_()
                self._set_input_buffer(incremental_state, input_buffer)
            else:
                # shift buffer
                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
            # append next input
            input_buffer[:, -1, :] = input[:, -1, :]
            input = utils.volatile_variable(input_buffer)
        with utils.maybe_no_grad():
            output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)
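A standalone illustration (hypothetical shapes) of the rolling input buffer above: each incremental step shifts the window left by one frame and appends the newest input, so the linearized weight always multiplies the last kw timesteps.

import torch

bsz, kw, dim = 2, 3, 4
buf = torch.zeros(bsz, kw, dim)
for step in range(5):
    frame = torch.full((bsz, dim), float(step))
    buf[:, :-1, :] = buf[:, 1:, :].clone()  # shift window left by one
    buf[:, -1, :] = frame                   # append the newest frame
# buf[:, :, 0] is now [2., 3., 4.] for every batch element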
Example #5
    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors

        grad_output = grad_output.data.contiguous()
        # Allocate the gradient buffers; the C++ kernel fills them in place.
        grad_input = grad_output.new(ctx.input_size).zero_()
        grad_weight = grad_output.new(ctx.weight_size).zero_()
        grad_bias = grad_output.new(ctx.weight_size[2])

        # The tensor type string selects the matching backend implementation.
        temporal_convolution_tbc.TemporalConvolutionTBC_backward(
            input.type().encode('utf-8'), grad_output, grad_input, grad_weight,
            grad_bias, input, weight)

        grad_input = utils.volatile_variable(grad_input)
        grad_weight = utils.volatile_variable(grad_weight)
        grad_bias = utils.volatile_variable(grad_bias)

        # One gradient per forward input; None for the non-tensor argument.
        return grad_input, grad_weight, grad_bias, None
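A minimal self-contained example (hypothetical op, not the TBC kernel) of the same torch.autograd.Function pattern: stash tensors in forward, compute gradient buffers in backward, and return one gradient per forward input.

import torch

class ScaledLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, scale):
        ctx.save_for_backward(input, weight)
        ctx.scale = scale                     # non-tensor state lives on ctx
        return scale * input.matmul(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = ctx.scale * grad_output.matmul(weight)
        grad_weight = ctx.scale * grad_output.t().matmul(input)
        return grad_input, grad_weight, None  # None for the non-tensor arg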
Example #6
    def _decode(self, tokens, encoder_outs, incremental_states, n_srcs=1):
        # wrap in Variable
        tokens = utils.volatile_variable(tokens)

        # Source sentences are weighted equally (for now)
        srcs_weights = [1 / n_srcs] * n_srcs

        avg_probs = None
        avg_attn = None
        for src_id, src_weight in enumerate(srcs_weights):
            for model_id, (model_weight, model) in enumerate(
                zip(self.model_weights, self.models)
            ):
                with utils.maybe_no_grad():
                    encoder_out = encoder_outs[src_id][model_id]
                    incremental_state = incremental_states[(src_id, model_id)]
                    decoder_out = list(
                        model.decoder(tokens, encoder_out, incremental_state)
                    )
                    decoder_out[0] = decoder_out[0][:, -1, :]
                    attn = decoder_out[1]
                    if len(decoder_out) == 3:
                        possible_translation_tokens = decoder_out[2]
                    else:
                        possible_translation_tokens = None
                probs = (
                    src_weight
                    * model_weight
                    * model.get_normalized_probs(decoder_out, log_probs=False)
                )
                if avg_probs is None:
                    avg_probs = probs
                else:
                    avg_probs.add_(probs)
                if attn is not None and src_id == self.align_to:
                    attn = attn[:, -1, :]
                    if avg_attn is None:
                        avg_attn = attn
                    else:
                        avg_attn.add_(attn)
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))

        return avg_probs, avg_attn, possible_translation_tokens
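A short sketch (hypothetical weights) of the bookkeeping above: each (source, model) pair contributes src_weight * model_weight, so as long as the model weights sum to 1 the mixture is already normalized and no final division by len(self.models) is needed before log_().

n_srcs = 2
srcs_weights = [1 / n_srcs] * n_srcs  # [0.5, 0.5]
model_weights = [0.7, 0.3]            # assumed to sum to 1

total = sum(s * m for s in srcs_weights for m in model_weights)
assert abs(total - 1.0) < 1e-9        # mixture weights already normalized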
Example #7
    def _decode(self, tokens, encoder_outs, incremental_states):
        # wrap in Variable
        tokens = utils.volatile_variable(tokens)

        avg_probs = None
        avg_attn = None
        for model_weight, model, encoder_out in zip(
            self.model_weights, self.models, encoder_outs
        ):
            with utils.maybe_no_grad():
                decoder_out = list(
                    model.decoder(tokens, encoder_out, incremental_states[model])
                )
                decoder_out[0] = decoder_out[0][:, -1, :]
                attn = decoder_out[1]
                if len(decoder_out) == 3:
                    possible_translation_tokens = decoder_out[2]
                else:
                    possible_translation_tokens = None
            probs = model_weight * model.get_normalized_probs(
                decoder_out, log_probs=False
            )
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn[:, -1, :]
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))

        return avg_probs, avg_attn, possible_translation_tokens
Example #8
    @staticmethod
    def backward(ctx, grad):
        # ctx.grad_input was stashed during the forward pass; apply the
        # chain rule by scaling it with the incoming gradient, and return
        # None for each remaining non-tensor forward argument.
        grad_input = ctx.grad_input
        if not isinstance(grad_input, torch.autograd.Variable):
            grad_input = utils.volatile_variable(grad_input)
        return grad_input * grad, None, None, None, None, None
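A compact sketch (hypothetical Function, assuming the forward pass can compute its own local gradient up front) of the precompute-in-forward trick this backward relies on.

import torch

class DetachedScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scale):
        ctx.grad_input = scale  # d(output)/d(input) is known up front
        return input * scale

    @staticmethod
    def backward(ctx, grad):
        # Chain rule: scale the stashed local gradient by the incoming one.
        return ctx.grad_input * grad, None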