def forward(self, *inputs, num_beams=0):
    with torch.set_grad_enabled(self.training):
        encoder_inputs, decoder_inputs = assert_dims(inputs, [2, None, None])  # dims: [sl, bs] for encoder and decoder
        # reset the states for the new batch
        bs = encoder_inputs.size(1)
        self.encoder.reset(bs)
        self.decoder.reset(bs)
        outputs = self.encoder(encoder_inputs)
        state = concat_bidir_state(self.encoder.encoder_layer.hidden,
                                   cell_type=self.cell_type,
                                   nlayers=self.nlayers,
                                   bidir=self.bidir)
        if self.training:
            self.decoder.pr_force = self.pr_force
            # if teacher forcing can be skipped (pr_force < 1), decode greedily with one beam
            nb = 1 if self.pr_force < 1 else 0
        else:
            nb = num_beams
        outputs_dec = self.decoder(decoder_inputs, hidden=state, num_beams=nb)
        predictions = outputs_dec[:decoder_inputs.size(0)] if num_beams == 0 else self.decoder.beam_outputs
        return predictions, [*outputs, *outputs_dec]
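
# `concat_bidir_state` merges the two directions of a bidirectional encoder's final hidden
# state so it can seed a unidirectional decoder. A minimal sketch of that idea, not the
# codebase's exact implementation (which also takes cell_type/nlayers); it assumes `hidden`
# is a list with one entry per layer, where LSTM entries are (h, c) tuples and GRU entries
# are single tensors of shape [ndir, bs, nhid]:
import torch


def concat_bidir_state_sketch(hidden, bidir):
    def merge(h):
        # [2, bs, nhid] -> [1, bs, 2 * nhid] when bidirectional, identity otherwise
        if not bidir:
            return h
        return torch.cat([h[0:1], h[1:2]], dim=-1)

    state = []
    for layer in hidden:
        if isinstance(layer, tuple):  # LSTM: (hidden, cell) pair
            state.append(tuple(merge(h) for h in layer))
        else:  # GRU/RNN: single hidden tensor
            state.append(merge(layer))
    return state
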
def forward(self, *inputs, num_beams=0):
    with torch.set_grad_enabled(self.training):
        encoder_inputs, decoder_inputs = assert_dims(inputs, [2, None, None])
        # dims: [num_utterances, sl, bs] for the encoder, [sl, bs] for the decoder
        # reset the states for the new batch
        num_utterances, max_sl, bs = encoder_inputs.size()
        self.reset_encoders(bs)
        outputs, session = self.encoder(encoder_inputs)
        # encode the decoder inputs with the query encoder and grab its final hidden state
        self.encoder.query_encoder.reset(bs)
        decoder_outputs = self.encoder.query_encoder(decoder_inputs)
        decoder_out = concat_bidir_state(self.encoder.query_encoder_layer.get_last_hidden_state(),
                                         cell_type=self.cell_type, nlayers=1,
                                         bidir=self.encoder.bidir)
        x = torch.cat([session, decoder_out], dim=-1)
        prior_log_var, prior_mu, recog_log_var, recog_mu, session = self.variational_encoding(session, x)
        # bag-of-words logits are only needed for the training losses
        bow_logits = self.bow_network(session).squeeze(0) if num_beams == 0 else None
        state, constraints = self.encoder_hidden_state_projection(session)
        outputs_dec, predictions = self.decoding(decoder_inputs, num_beams, state)
        if num_beams == 0:
            return [predictions, recog_mu, recog_log_var, prior_mu, prior_log_var, bow_logits], \
                   [*outputs, *outputs_dec]
        else:
            return predictions, [*outputs, *outputs_dec]
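
# The `variational_encoding` step follows the CVAE/VHRED recipe: a prior network that sees
# only the session state, a recognition network that also sees the encoded response, and a
# reparameterized latent sample passed downstream. A minimal sketch under those assumptions;
# the layer names and sizes here are hypothetical, not the codebase's exact module:
import torch
import torch.nn as nn


class VariationalEncodingSketch(nn.Module):
    def __init__(self, nhid, latent_dim):
        super().__init__()
        self.prior = nn.Linear(nhid, latent_dim * 2)             # conditioned on session only
        self.recognition = nn.Linear(nhid * 2, latent_dim * 2)   # also sees the response encoding

    def forward(self, session, x):
        prior_mu, prior_log_var = self.prior(session).chunk(2, dim=-1)
        recog_mu, recog_log_var = self.recognition(x).chunk(2, dim=-1)
        # sample from the recognition posterior at train time, from the prior otherwise
        mu, log_var = (recog_mu, recog_log_var) if self.training else (prior_mu, prior_log_var)
        z = mu + torch.randn_like(mu) * (0.5 * log_var).exp()    # reparameterization trick
        return prior_log_var, prior_mu, recog_log_var, recog_mu, z
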
def test_concat_bidirs(cell_type, input_size, output_size, bidir):
    cell = Cell(cell_type=cell_type, input_size=input_size, output_size=output_size, bidir=bidir)
    cell.reset(bs=32)
    output = concat_bidir_state(cell.hidden, bidir=bidir, cell_type=cell_type, nlayers=1)
    # a unidirectional cell with doubled output size should yield states of the same shapes
    cell2 = Cell(cell_type=cell_type, input_size=input_size,
                 output_size=output_size * 2 if bidir else output_size, bidir=False)
    cell2.reset(bs=32)
    dec_state = cell2.hidden
    for layer_in, layer_out in zip(output, dec_state):
        if isinstance(layer_in, (tuple, list)):  # LSTM states are (h, c) tuples
            for h1, h2 in zip(layer_in, layer_out):
                assert h1.size() == h2.size()
        else:
            assert layer_in.size() == layer_out.size()
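
# The test's arguments are typically supplied by pytest fixtures or parameterization; a
# hypothetical `pytest.mark.parametrize` setup (the cell types and sizes are assumptions,
# not taken from the repository's test configuration) could drive it like this:
import pytest


@pytest.mark.parametrize("cell_type", ["lstm", "gru"])
@pytest.mark.parametrize("bidir", [False, True])
@pytest.mark.parametrize("input_size, output_size", [(300, 100)])
def test_concat_bidirs(cell_type, input_size, output_size, bidir):
    ...  # body as above
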
def query_level_encoding(self, encoder_inputs):
    query_encoder_outputs = []
    for index, context in enumerate(encoder_inputs):
        self.query_encoder.reset(bs=encoder_inputs.size(2))
        state = self.query_encoder.hidden
        outputs = self.query_encoder(context, state)  # context has size [sl, bs]
        out = concat_bidir_state(self.query_encoder.encoder_layer.get_last_hidden_state(),
                                 cell_type=self.cell_type, nlayers=1,
                                 bidir=self.query_encoder.encoder_layer.bidir)
        query_encoder_outputs.append(out)  # get the last sl output of the query_encoder
        # BPTT: if the dialogue is too long, repackage the first half of the outputs to
        # truncate gradient backpropagation and fit the graph into memory
        # out = repackage_var(outputs[-1][-1]) \
        #     if max_sl * num_utterances > self.BPTT_MAX_UTTERANCES and index <= num_utterances // 2 \
        #     else outputs[-1][-1]
    query_encoder_outputs = torch.cat(query_encoder_outputs, dim=0)  # [cl, bs, nhid]
    return query_encoder_outputs
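
# The commented-out truncation relies on a `repackage_var`-style helper that cuts the
# autograd graph so gradients stop flowing through earlier utterances. A minimal sketch of
# that well-known pattern (detach tensors, recurse into tuples/lists of states):
import torch


def repackage_var_sketch(h):
    # detach a tensor from its history, or recursively detach every tensor in a container
    if isinstance(h, torch.Tensor):
        return h.detach()
    return type(h)(repackage_var_sketch(v) for v in h)
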
def forward(self, *inputs, num_beams=0):
    encoder_inputs, decoder_inputs = assert_dims(inputs, [2, None, None])  # dims: [sl, bs] for encoder and decoder
    # reset the states for the new batch
    bs = encoder_inputs.size(1)
    self.encoder.reset(bs)
    self.decoder.reset(bs)
    raw_outputs, outputs = self.encoder(encoder_inputs)
    state = concat_bidir_state(self.encoder.hidden)
    raw_outputs_dec, outputs_dec = self.decoder(decoder_inputs, hidden=state, num_beams=num_beams)
    if num_beams == 0:
        # use the output of the projection module
        predictions = assert_dims(outputs_dec[-1], [None, bs, self.nt])  # dims: [sl, bs, nt]
    else:
        # use argmax or beam search predictions
        predictions = assert_dims(self.decoder.beam_outputs, [None, bs, num_beams])  # dims: [sl, bs, nb]
    return predictions, [*raw_outputs, *raw_outputs_dec], [*outputs, *outputs_dec]
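
# `assert_dims` is used throughout as an inline shape guard. A minimal sketch of such a
# helper, assuming `None` means "any size", an int constrains that axis, and a list/tuple
# argument consumes the first entry as its expected length; the library's real version may
# differ in details:
def assert_dims_sketch(x, dims):
    if isinstance(x, (list, tuple)):
        # e.g. assert_dims_sketch(inputs, [2, None, None]): two tensors, each 2-dimensional
        assert dims[0] is None or len(x) == dims[0], f"expected {dims[0]} items, got {len(x)}"
        for item in x:
            assert_dims_sketch(item, dims[1:])
        return x
    assert x.dim() == len(dims), f"expected {len(dims)} dims, got {x.dim()}"
    for axis, expected in enumerate(dims):
        if expected is not None:
            assert x.size(axis) == expected, f"axis {axis}: expected {expected}, got {x.size(axis)}"
    return x  # returned so it can be used inline, as in the forward passes above
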
def forward(self, *inputs, num_beams=0):
    encoder_inputs, decoder_inputs = assert_dims(inputs, [2, None, None])
    # dims: [num_utterances, sl, bs] for the encoder, [sl, bs] for the decoder
    # reset the states for the new batch
    num_utterances, max_sl, bs = encoder_inputs.size()
    self.reset_encoders(bs)
    query_encoder_outputs = self.query_level_encoding(encoder_inputs)
    outputs = self.se_enc(query_encoder_outputs)
    # the session state is the last layer's hidden state of the session encoder
    session = self.se_enc.hidden[-1]
    self.query_encoder.reset(bs)
    decoder_outputs = self.query_encoder(decoder_inputs)
    decoder_out = concat_bidir_state(self.query_encoder.encoder_layer.hidden[-1],
                                     cell_type=self.cell_type, nlayers=1,
                                     bidir=self.query_encoder.encoder_layer.bidir)
    x = torch.cat([session, decoder_out], dim=-1)
    prior_log_var, prior_mu, recog_log_var, recog_mu, session = self.variational_encoding(session, x)
    bow_logits = self.bow_network(session).squeeze(0) if num_beams == 0 else None
    state = self.decoder.hidden
    # if there are multiple decoder layers, only the first layer's state is seeded from the
    # session (the last layer's output at the last step); the other layers keep their reset state
    state[0] = self.decoder_state_linear(session)
    outputs_dec, predictions = self.decoding(decoder_inputs, num_beams, state)
    if num_beams == 0:
        return [predictions, recog_mu, recog_log_var, prior_mu, prior_log_var, bow_logits], \
               [*outputs, *outputs_dec]
    else:
        return predictions, [*outputs, *outputs_dec]
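
# `bow_logits` is the input to a bag-of-words auxiliary loss in CVAE-style dialogue models:
# the latent must predict every response token independently of position, which pushes
# information into the latent and mitigates posterior collapse. A hedged sketch of how such
# a loss could be computed; the function and argument names are assumptions, not this
# codebase's API:
import torch.nn.functional as F


def bow_loss_sketch(bow_logits, target_tokens, pad_idx):
    # bow_logits: [bs, ntokens]; target_tokens: [sl, bs]
    log_probs = F.log_softmax(bow_logits, dim=-1)                # [bs, ntokens]
    targets = target_tokens.transpose(0, 1)                      # [bs, sl]
    token_log_probs = log_probs.gather(-1, targets)              # log-prob of each target token
    mask = (targets != pad_idx).float()                          # ignore padding positions
    return -(token_log_probs * mask).sum(dim=-1).mean()          # sum over tokens, mean over batch
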