Example #1
    def forward(self, *inputs, num_beams=0):
        with torch.set_grad_enabled(self.training):
            encoder_inputs, decoder_inputs = assert_dims(
                inputs, [2, None, None])  # dims: [sl, bs] for encoder and decoder
            # reset the states for the new batch
            bs = encoder_inputs.size(1)
            self.encoder.reset(bs)
            self.decoder.reset(bs)
            outputs = self.encoder(encoder_inputs)
            state = concat_bidir_state(self.encoder.encoder_layer.hidden,
                                       cell_type=self.cell_type,
                                       nlayers=self.nlayers,
                                       bidir=self.bidir)
            if self.training:
                self.decoder.pr_force = self.pr_force
                # with partial teacher forcing the decoder must also generate
                # its own tokens, so decode greedily with a single beam
                nb = 1 if self.pr_force < 1 else 0
            else:
                nb = num_beams
            outputs_dec = self.decoder(decoder_inputs, hidden=state, num_beams=nb)
            predictions = (outputs_dec[:decoder_inputs.size(0)]
                           if num_beams == 0 else self.decoder.beam_outputs)
        return predictions, [*outputs, *outputs_dec]
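
All six examples feed the result of concat_bidir_state into a unidirectional decoder, and the shape test in Example #3 below checks that the merged state matches a unidirectional cell with twice the hidden size. The helper itself is not shown on this page; a minimal sketch consistent with that test, assuming hidden is a list of per-layer states (a tensor per layer for GRUs, an (h, c) tuple per layer for LSTMs) whose tensors are direction-stacked as [2, bs, nhid], could look like this:

import torch

def concat_bidir_state(hidden, cell_type="gru", nlayers=1, bidir=True):
    # Hypothetical reconstruction: merge the two direction slices of each
    # layer's state [2, bs, nhid] into one [1, bs, 2 * nhid] tensor so it
    # matches the state shape a unidirectional decoder expects.
    if not bidir:
        return hidden

    def merge(h):
        # h: [2, bs, nhid] -> [1, bs, 2 * nhid]
        return torch.cat([h[0], h[1]], dim=-1).unsqueeze(0)

    merged = []
    for layer_state in hidden[:nlayers]:
        if cell_type == "lstm":
            merged.append(tuple(merge(h) for h in layer_state))  # (h, c) pair
        else:
            merged.append(merge(layer_state))
    return merged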
Example #2
    def forward(self, *inputs, num_beams=0):
        with torch.set_grad_enabled(self.training):
            encoder_inputs, decoder_inputs = assert_dims(
                inputs,
                [2, None, None])  # dims: [sl, bs] for encoder and decoder
            # reset the states for the new batch
            num_utterances, max_sl, bs = encoder_inputs.size()
            self.reset_encoders(bs)
            outputs, session = self.encoder(encoder_inputs)
            self.encoder.query_encoder.reset(bs)
            decoder_outputs = self.encoder.query_encoder(decoder_inputs)
            decoder_out = concat_bidir_state(
                self.encoder.query_encoder_layer.get_last_hidden_state(),
                cell_type=self.cell_type,
                nlayers=1,
                bidir=self.encoder.bidir)
            x = torch.cat([session, decoder_out], dim=-1)
            prior_log_var, prior_mu, recog_log_var, recog_mu, session = self.variational_encoding(
                session, x)
            bow_logits = self.bow_network(session).squeeze(
                0) if num_beams == 0 else None

            state, constraints = self.encoder_hidden_state_projection(session)
            outputs_dec, predictions = self.decoding(decoder_inputs, num_beams,
                                                     state)
            if num_beams == 0:
                return [
                    predictions, recog_mu, recog_log_var, prior_mu,
                    prior_log_var, bow_logits
                ], [*outputs, *outputs_dec]
            else:
                return predictions, [*outputs, *outputs_dec]
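
Examples #2 and #6 follow a CVAE-style pattern: variational_encoding produces a prior Gaussian and a recognition Gaussian and resamples the session code, which also drives a bag-of-words auxiliary head. A hedged sketch of such a reparameterized step, matching only the return order used above (prior_net and recog_net are assumed names, not this page's API), might be:

import torch

def variational_encoding(prior_in, recog_in, prior_net, recog_net):
    # Hypothetical CVAE step: each network maps its input to (mu, log_var);
    # the session code is sampled from the recognition Gaussian via the
    # reparameterization trick.
    prior_mu, prior_log_var = prior_net(prior_in)
    recog_mu, recog_log_var = recog_net(recog_in)
    std = torch.exp(0.5 * recog_log_var)
    session = recog_mu + std * torch.randn_like(std)
    return prior_log_var, prior_mu, recog_log_var, recog_mu, session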
Example #3
def test_concat_bidirs(cell_type, input_size, output_size, bidir):
    cell = Cell(cell_type=cell_type, input_size=input_size,
                output_size=output_size, bidir=bidir)
    cell.reset(bs=32)
    output = concat_bidir_state(cell.hidden, bidir=bidir, cell_type=cell_type,
                                nlayers=1)
    # a unidirectional cell with twice the hidden size should hold states of
    # exactly the shapes produced by concat_bidir_state
    cell2 = Cell(cell_type=cell_type, input_size=input_size,
                 output_size=output_size * 2 if bidir else output_size,
                 bidir=False)
    cell2.reset(bs=32)
    dec_state = cell2.hidden
    for layer_in, layer_out in zip(output, dec_state):
        if isinstance(layer_in, (tuple, list)):  # LSTM state: (h, c) per layer
            for h1, h2 in zip(layer_in, layer_out):
                assert h1.size() == h2.size()
        else:
            assert layer_in.size() == layer_out.size()
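
For concreteness, with output_size=3, bidir=True, and a batch of 32, the per-layer shapes the test compares would be (under the same assumptions as the sketch in Example #1):

# cell.hidden (bidirectional):    [2, 32, 3]   one slice per direction
# concat_bidir_state output:      [1, 32, 6]   directions concatenated
# cell2.hidden (unidirectional):  [1, 32, 6]   output_size doubled to 6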
Example #4
    def query_level_encoding(self, encoder_inputs):
        query_encoder_outputs = []
        for index, context in enumerate(encoder_inputs):
            self.query_encoder.reset(bs=encoder_inputs.size(2))
            state = self.query_encoder.hidden
            outputs = self.query_encoder(context, state)  # context has dims [sl, bs]
            out = concat_bidir_state(
                self.query_encoder.encoder_layer.get_last_hidden_state(),
                cell_type=self.cell_type,
                nlayers=1,
                bidir=self.query_encoder.encoder_layer.bidir)
            # keep the final state of the query_encoder for this utterance
            query_encoder_outputs.append(out)
            # Truncated BPTT: if the dialogue is too long, repackage the outputs
            # of the first half of the utterances to shorten gradient
            # backpropagation and fit the graph into memory.
            # out = repackage_var(outputs[-1][-1]) \
            #     if max_sl * num_utterances > self.BPTT_MAX_UTTERANCES \
            #        and index <= num_utterances // 2 else outputs[-1][-1]
        query_encoder_outputs = torch.cat(query_encoder_outputs, dim=0)  # [cl, bs, nhid]
        return query_encoder_outputs
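
Since the loop runs over the first dimension of encoder_inputs, each iteration encodes one utterance; a brief shape walk-through (illustrative, following the dims noted in the comments above):

# encoder_inputs: [num_utterances, sl, bs]; each step keeps a [1, bs, nhid]
# summary state, so concatenating along dim 0 yields
# [num_utterances, bs, nhid] (the [cl, bs, nhid] noted above).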
Example #5
    def forward(self, *inputs, num_beams=0):
        encoder_inputs, decoder_inputs = assert_dims(
            inputs, [2, None, None])  # dims: [sl, bs] for encoder and decoder
        # reset the states for the new batch
        bs = encoder_inputs.size(1)
        self.encoder.reset(bs)
        self.decoder.reset(bs)
        raw_outputs, outputs = self.encoder(encoder_inputs)
        state = concat_bidir_state(self.encoder.hidden)
        raw_outputs_dec, outputs_dec = self.decoder(decoder_inputs,
                                                    hidden=state,
                                                    num_beams=num_beams)
        if num_beams == 0:
            # use the output of the projection module
            predictions = assert_dims(
                outputs_dec[-1], [None, bs, self.nt])  # dims: [sl, bs, nt]
        else:
            # use argmax or beam search predictions
            predictions = assert_dims(
                self.decoder.beam_outputs,
                [None, bs, num_beams])  # dims: [sl, bs, nb]
        return predictions, [*raw_outputs,
                             *raw_outputs_dec], [*outputs, *outputs_dec]
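
The num_beams switch above is the calling convention shared by the forward methods on this page; a hypothetical call pattern (model, enc_ids, and dec_ids are illustrative names) would be:

# greedy / teacher-forced path: predictions has dims [sl, bs, nt]
predictions, raw, out = model(enc_ids, dec_ids)  # num_beams defaults to 0
# beam search with 3 beams: predictions has dims [sl, bs, 3]
predictions, raw, out = model(enc_ids, dec_ids, num_beams=3)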
Example #6
    def forward(self, *inputs, num_beams=0):
        encoder_inputs, decoder_inputs = assert_dims(
            inputs, [2, None, None])  # dims: [sl, bs] for encoder and decoder
        # reset the states for the new batch
        num_utterances, max_sl, bs = encoder_inputs.size()
        self.reset_encoders(bs)
        query_encoder_outputs = self.query_level_encoding(encoder_inputs)

        outputs = self.se_enc(query_encoder_outputs)
        session = self.se_enc.hidden[-1]
        self.query_encoder.reset(bs)
        decoder_outputs = self.query_encoder(decoder_inputs)
        decoder_out = concat_bidir_state(
            self.query_encoder.encoder_layer.hidden[-1],
            cell_type=self.cell_type,
            nlayers=1,
            bidir=self.query_encoder.encoder_layer.bidir)
        x = torch.cat([session, decoder_out], dim=-1)
        prior_log_var, prior_mu, recog_log_var, recog_mu, session = self.variational_encoding(
            session, x)
        bow_logits = self.bow_network(session).squeeze(
            0) if num_beams == 0 else None

        state = self.decoder.hidden
        # if there are multiple layers, only the first layer's state is set from
        # the session vector (the last layer's output at the last step); the
        # remaining layers keep their freshly reset state
        state[0] = self.decoder_state_linear(session)
        outputs_dec, predictions = self.decoding(decoder_inputs, num_beams,
                                                 state)
        if num_beams == 0:
            return [
                predictions, recog_mu, recog_log_var, prior_mu, prior_log_var,
                bow_logits
            ], [*outputs, *outputs_dec]
        else:
            return predictions, [*outputs, *outputs_dec]