def _decode(self, tokens, encoder_outs):
    # wrap in Variable
    tokens = utils.volatile_variable(tokens)

    avg_probs = None
    avg_attn = None
    for model, encoder_out in zip(self.models, encoder_outs):
        with utils.maybe_no_grad():
            decoder_out, attn = model.decoder(tokens, encoder_out)
        probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
        if avg_probs is None:
            avg_probs = probs
        else:
            avg_probs.add_(probs)
        if attn is not None:
            attn = attn[:, -1, :].data
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
    avg_probs.div_(len(self.models))
    avg_probs.log_()
    if avg_attn is not None:
        avg_attn.div_(len(self.models))
    return avg_probs, avg_attn

def _decode(self, tokens, encoder_outs, src_doctopic_reshaped, incremental_states):
    # wrap in Variable
    tokens = utils.volatile_variable(tokens)

    avg_probs = None
    avg_attn = None
    for model, encoder_out in zip(self.models, encoder_outs):
        with utils.maybe_no_grad():
            decoder_out, attn = model.decoder(
                tokens, encoder_out, src_doctopic_reshaped, incremental_states[model]
            )
        probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
        if avg_probs is None:
            avg_probs = probs
        else:
            avg_probs.add_(probs)
        if attn is not None:
            attn = attn[:, -1, :].data
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
    avg_probs.div_(len(self.models))
    avg_probs.log_()
    if avg_attn is not None:
        avg_attn.div_(len(self.models))
    return avg_probs, avg_attn

def _decode(self, tokens, encoder_outs, incremental_states):
    # wrap in Variable
    tokens = utils.volatile_variable(tokens)

    avg_probs = None
    avg_attn = None
    for model, encoder_out in zip(self.models, encoder_outs):
        with utils.maybe_no_grad():
            if incremental_states[model] is not None:
                decoder_out = list(
                    model.decoder(tokens, encoder_out, incremental_states[model]))
            else:
                decoder_out = list(model.decoder(tokens, encoder_out))
            decoder_out[0] = decoder_out[0][:, -1, :]
            attn = decoder_out[1]
        probs = model.get_normalized_probs(decoder_out, log_probs=False).data
        if avg_probs is None:
            avg_probs = probs
        else:
            avg_probs.add_(probs)
        if attn is not None:
            attn = attn[:, -1, :].data
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
    avg_probs.div_(len(self.models))
    avg_probs.log_()
    if avg_attn is not None:
        avg_attn.div_(len(self.models))
    return avg_probs, avg_attn

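# A minimal standalone sketch of the ensemble averaging performed above,
# using toy per-model tensors; `model_probs` is a hypothetical stand-in
# for the per-model `get_normalized_probs` outputs. Note that the
# probabilities are averaged in probability space and only logged
# afterwards, which is not the same as averaging log-probabilities.
import torch

model_probs = [
    torch.tensor([[0.7, 0.2, 0.1]]),
    torch.tensor([[0.5, 0.4, 0.1]]),
]
avg_probs = None
for probs in model_probs:
    if avg_probs is None:
        avg_probs = probs.clone()
    else:
        avg_probs.add_(probs)
avg_probs.div_(len(model_probs))  # arithmetic mean over models
avg_probs.log_()                  # take the log *after* averaging
# avg_probs now holds log(0.5 * (p1 + p2)) per vocabulary entry
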
def forward(self, input, incremental_state=None):
    """
    Input: Time x Batch x Channel.

    Args:
        incremental_state: Used to buffer signal; if not None, then input is
            expected to contain a single frame. If the input order changes
            between time steps, call reorder_incremental_state.
    """
    if incremental_state is None:
        return super().forward(input)

    # reshape weight
    weight = self._get_linearized_weight()
    kw = self.kernel_size[0]

    bsz = input.size(0)  # input: bsz x len x dim
    if kw > 1:
        input = input.data
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is None:
            input_buffer = input.new(bsz, kw, input.size(2)).zero_()
            self._set_input_buffer(incremental_state, input_buffer)
        else:
            # shift buffer
            input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
        # append next input
        input_buffer[:, -1, :] = input[:, -1, :]
        input = utils.volatile_variable(input_buffer)
    with utils.maybe_no_grad():
        output = F.linear(input.view(bsz, -1), weight, self.bias)
    return output.view(bsz, 1, -1)

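# A minimal sketch of the rolling input buffer used above, assuming a
# kernel width `kw` and feature size `dim`; all names here are
# illustrative, not part of the module's API. At each step the buffer
# drops its oldest frame and appends the newest, so the linearized
# convolution reduces to a single matmul over the flattened buffer.
import torch

bsz, kw, dim = 2, 3, 4
input_buffer = torch.zeros(bsz, kw, dim)
for step in range(5):
    new_frame = torch.randn(bsz, dim)
    input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()  # shift frames left
    input_buffer[:, -1, :] = new_frame                        # append newest frame
    flat = input_buffer.view(bsz, -1)  # (bsz, kw * dim), ready for F.linear
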
def backward(ctx, grad_output):
    input, weight = ctx.saved_tensors
    grad_output = grad_output.data.contiguous()
    grad_input = grad_output.new(ctx.input_size).zero_()
    grad_weight = grad_output.new(ctx.weight_size).zero_()
    grad_bias = grad_output.new(ctx.weight_size[2])
    temporal_convolution_tbc.TemporalConvolutionTBC_backward(
        input.type().encode('utf-8'),
        grad_output,
        grad_input,
        grad_weight,
        grad_bias,
        input,
        weight)
    grad_input = utils.volatile_variable(grad_input)
    grad_weight = utils.volatile_variable(grad_weight)
    grad_bias = utils.volatile_variable(grad_bias)
    return grad_input, grad_weight, grad_bias, None

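# A self-contained sketch of the custom-autograd pattern above, using a
# plain Python backward instead of the TemporalConvolutionTBC C extension
# (whose exact calling convention is not shown here). `ScaleBy` is a toy
# function: forward stashes what backward needs on `ctx`, and backward
# returns one gradient per forward argument (None for non-differentiable
# arguments, mirroring the trailing None above).
import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scale):
        ctx.scale = scale
        return input * scale

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward arg: grad for `input`, None for `scale`
        return grad_output * ctx.scale, None

x = torch.randn(3, requires_grad=True)
y = ScaleBy.apply(x, 2.0).sum()
y.backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
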
def _decode(self, tokens, encoder_outs, incremental_states, n_srcs=1):
    # wrap in Variable
    tokens = utils.volatile_variable(tokens)

    # Source sentences are weighted equally (for now)
    srcs_weights = [1 / n_srcs] * n_srcs

    avg_probs = None
    avg_attn = None
    for src_id, src_weight in enumerate(srcs_weights):
        for model_id, (model_weight, model) in enumerate(
            zip(self.model_weights, self.models)
        ):
            with utils.maybe_no_grad():
                encoder_out = encoder_outs[src_id][model_id]
                incremental_state = incremental_states[(src_id, model_id)]
                decoder_out = list(
                    model.decoder(tokens, encoder_out, incremental_state)
                )
                decoder_out[0] = decoder_out[0][:, -1, :]
                attn = decoder_out[1]
                if len(decoder_out) == 3:
                    possible_translation_tokens = decoder_out[2]
                else:
                    possible_translation_tokens = None
            probs = (
                src_weight
                * model_weight
                * model.get_normalized_probs(decoder_out, log_probs=False)
            )
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None and src_id == self.align_to:
                attn = attn[:, -1, :]
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
    avg_probs.log_()
    if avg_attn is not None:
        avg_attn.div_(len(self.models))
    return avg_probs, avg_attn, possible_translation_tokens

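# A toy numeric check of the two-level weighting above, assuming uniform
# source weights and user-supplied model weights; all values here are
# illustrative. Each (source, model) pair contributes
# src_weight * model_weight * probs, so when the model weights sum to 1
# the accumulated distribution stays normalized and no final div_ is
# needed before log_().
import torch

n_srcs = 2
srcs_weights = [1 / n_srcs] * n_srcs
model_weights = [0.6, 0.4]
probs_per_pair = torch.tensor([[0.7, 0.3]])  # same toy dist for every pair

avg_probs = None
for src_weight in srcs_weights:
    for model_weight in model_weights:
        contrib = src_weight * model_weight * probs_per_pair
        avg_probs = contrib if avg_probs is None else avg_probs + contrib
assert torch.allclose(avg_probs.sum(), torch.tensor(1.0))
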
def _decode(self, tokens, encoder_outs, incremental_states):
    # wrap in Variable
    tokens = utils.volatile_variable(tokens)

    avg_probs = None
    avg_attn = None
    for model_weight, model, encoder_out in zip(
        self.model_weights, self.models, encoder_outs
    ):
        with utils.maybe_no_grad():
            decoder_out = list(
                model.decoder(tokens, encoder_out, incremental_states[model])
            )
            decoder_out[0] = decoder_out[0][:, -1, :]
            attn = decoder_out[1]
            if len(decoder_out) == 3:
                possible_translation_tokens = decoder_out[2]
            else:
                possible_translation_tokens = None
        probs = model_weight * model.get_normalized_probs(
            decoder_out, log_probs=False
        )
        if avg_probs is None:
            avg_probs = probs
        else:
            avg_probs.add_(probs)
        if attn is not None:
            attn = attn[:, -1, :]
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
    avg_probs.log_()
    if avg_attn is not None:
        avg_attn.div_(len(self.models))
    return avg_probs, avg_attn, possible_translation_tokens

def backward(ctx, grad):
    grad_input = ctx.grad_input
    if not isinstance(grad_input, torch.autograd.Variable):
        grad_input = utils.volatile_variable(grad_input)
    return grad_input * grad, None, None, None, None, None