Example #1
    def batch_beam_search(self,
                          x,
                          beam_size=5,
                          max_length=255,
                          n_best=1,
                          length_penalty=.2):
        # |x[0]| = (batch_size, n)
        batch_size = x[0].size(0)

        mask = self._generate_mask(x[0], x[1])
        # |mask| = (batch_size, n)
        x = x[0]

        mask_enc = torch.stack([mask for _ in range(x.size(1))], dim=1)
        mask_dec = mask.unsqueeze(1)
        # |mask_enc| = (batch_size, n, n)
        # |mask_dec| = (batch_size, 1, n)

        z = self.emb_dropout(self._position_encoding(self.emb_enc(x)))
        z, _ = self.encoder(z, mask_enc)
        # |z| = (batch_size, n, hidden_size)

        # Initialize one 'SingleBeamSearchSpace' per sample in the batch.
        # Each space registers a placeholder ('prev_state_0', ..., one per
        # decoder block plus the embedding layer) for the cumulative
        # time-steps, concatenated along batch dimension 0.
        spaces = [
            SingleBeamSearchSpace(
                z.device,
                [('prev_state_%d' % j, None, 0)
                 for j in range(len(self.decoder._modules) + 1)],
                beam_size=beam_size,
                max_length=max_length,
            ) for i in range(batch_size)
        ]
        done_cnt = [space.is_done() for space in spaces]

        length = 0
        while sum(done_cnt) < batch_size and length <= max_length:
            # Build a fabricated mini-batch containing only the unfinished
            # samples; its batch size is (#unfinished samples) * beam_size.
            fab_input, fab_z, fab_mask = [], [], []
            fab_prevs = [[] for _ in range(len(self.decoder._modules) + 1)]

            for i, space in enumerate(spaces):
                if space.is_done() == 0:
                    tmp = space.get_batch()

                    y_hat_ = tmp[0]
                    tmp = tmp[1:]

                    fab_input += [y_hat_]
                    for j, prev_ in enumerate(tmp):
                        if prev_ is not None:
                            fab_prevs[j] += [prev_]
                        else:
                            fab_prevs[j] = None

                    fab_z += [z[i].unsqueeze(0)] * beam_size
                    fab_mask += [mask_dec[i].unsqueeze(0)] * beam_size

            fab_input = torch.cat(fab_input, dim=0)
            for i, fab_prev in enumerate(fab_prevs):
                if fab_prev is not None:
                    fab_prevs[i] = torch.cat(fab_prev, dim=0)
            fab_z = torch.cat(fab_z, dim=0)
            fab_mask = torch.cat(fab_mask, dim=0)
            # |fab_input| = (current_batch_size, 1)
            # |fab_prevs[i]| = (current_batch_size, length, hidden_size)
            # |fab_z| = (current_batch_size, n, hidden_size)
            # |fab_mask| = (current_batch_size, 1, n)

            # Unlike in training, only the last time-step is fed through the
            # decoder at inference time; earlier time-steps are kept in 'fab_prevs'.
            h_t = self.emb_dropout(
                self._position_encoding(self.emb_dec(fab_input),
                                        init_pos=length))
            # |h_t| = (current_batch_size, 1, hidden_size)
            if fab_prevs[0] is None:
                fab_prevs[0] = h_t
            else:
                fab_prevs[0] = torch.cat([fab_prevs[0], h_t], dim=1)

            for i, block in enumerate(self.decoder._modules.values()):
                prev = fab_prevs[i]
                # |prev| = (current_batch_size, m, hidden_size)

                h_t, _, _, _ = block(h_t, fab_z, fab_mask, prev)
                # |h_t| = (current_batch_size, 1, hidden_size)

                if fab_prevs[i + 1] is None:
                    fab_prevs[i + 1] = h_t
                else:
                    fab_prevs[i + 1] = torch.cat([fab_prevs[i + 1], h_t],
                                                 dim=1)

            y_hat_t = self.softmax(self.generator(h_t))
            # |y_hat_t| = (current_batch_size, 1, output_size)

            # Separate the fabricated results back into per-sample slices
            # and let each space pick its k-best candidates.
            cnt = 0
            for space in spaces:
                if space.is_done() == 0:
                    # Decide the slice range for this sample.
                    from_index = cnt * beam_size
                    to_index = from_index + beam_size

                    space.collect_result(
                        y_hat_t[from_index:to_index],
                        [(
                            'prev_state_%d' % i,
                            fab_prevs[i][from_index:to_index],
                        ) for i in range(len(self.decoder._modules) + 1)],
                    )

                    cnt += 1

            done_cnt = [space.is_done() for space in spaces]
            length += 1

        # Pick the n-best hypotheses for each sample.
        batch_sentences = []
        batch_probs = []

        for i, space in enumerate(spaces):
            sentences, probs = space.get_n_best(n_best,
                                                length_penalty=length_penalty)

            batch_sentences += [sentences]
            batch_probs += [probs]

        return batch_sentences, batch_probs
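
Example #1 above is the Transformer-style variant: it expects x to be a pair of (token indices, lengths) and returns, for every input sample, the n_best decoded index sequences together with their scores. A minimal calling sketch follows; model, tokens, and lengths are illustrative names that do not appear in the example itself.

    import torch

    def translate_batch(model, tokens, lengths):
        # 'model' is assumed to be a trained instance of the class that
        # defines Example #1's batch_beam_search; 'tokens' is a LongTensor of
        # shape (batch_size, n) and 'lengths' holds the per-sample lengths.
        model.eval()
        with torch.no_grad():
            sentences, probs = model.batch_beam_search(
                (tokens, lengths),
                beam_size=5,
                max_length=255,
                n_best=1,
                length_penalty=.2,
            )
        # sentences[i] holds the n_best index sequences for the i-th sample,
        # probs[i] the corresponding scores.
        return sentences, probs
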
Example #2
    def batch_beam_search(self,
                          src,
                          beam_size=5,
                          max_length=255,
                          n_best=1,
                          length_penalty=.2):
        mask, x_length = None, None

        if isinstance(src, tuple):
            x, x_length = src
            mask = self.generate_mask(x, x_length)
            # |mask| = (batch_size, length)
        else:
            x = src
        batch_size = x.size(0)

        emb_src = self.emb_src(x)
        h_src, h_0_tgt = self.encoder((emb_src, x_length))
        # |h_src| = (batch_size, length, hidden_size)
        h_0_tgt, c_0_tgt = h_0_tgt
        h_0_tgt = h_0_tgt.transpose(0, 1).contiguous().view(
            batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
        c_0_tgt = c_0_tgt.transpose(0, 1).contiguous().view(
            batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
        # |h_0_tgt| = (n_layers, batch_size, hidden_size)
        h_0_tgt = (h_0_tgt, c_0_tgt)

        # Initialize one 'SingleBeamSearchSpace' per sample in the batch.
        spaces = [
            SingleBeamSearchSpace(
                h_src.device,
                [
                    ('hidden_state', h_0_tgt[0][:, i, :].unsqueeze(1), 1),
                    ('cell_state', h_0_tgt[1][:, i, :].unsqueeze(1), 1),
                    ('h_t_1_tilde', None, 0),
                ],
                beam_size=beam_size,
                max_length=max_length,
            ) for i in range(batch_size)
        ]
        done_cnt = [space.is_done() for space in spaces]

        length = 0
        # Keep looping while at least one sample is unfinished
        # and the current length does not exceed max_length.
        while sum(done_cnt) < batch_size and length <= max_length:
            # current_batch_size = (batch_size - sum(done_cnt)) * beam_size

            # Initialize fabricated variables.
            # While batch-beam-search is running, the fabricated mini-batch is
            # 'beam_size' times bigger than the number of unfinished samples.
            fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
            fab_h_src, fab_mask = [], []

            # Build the fabricated mini-batch in a non-parallel way;
            # this may cause a bottleneck.
            for i, space in enumerate(spaces):
                # Batchify only if inference for this sample is not finished yet.
                if space.is_done() == 0:
                    y_hat_, hidden_, cell_, h_t_tilde_ = space.get_batch()

                    fab_input += [y_hat_]
                    fab_hidden += [hidden_]
                    fab_cell += [cell_]
                    if h_t_tilde_ is not None:
                        fab_h_t_tilde += [h_t_tilde_]
                    else:
                        fab_h_t_tilde = None

                    fab_h_src += [h_src[i, :, :]] * beam_size
                    fab_mask += [mask[i, :]] * beam_size

            # Now, concatenate list of tensors.
            fab_input = torch.cat(fab_input, dim=0)
            fab_hidden = torch.cat(fab_hidden, dim=1)
            fab_cell = torch.cat(fab_cell, dim=1)
            if fab_h_t_tilde is not None:
                fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
            fab_h_src = torch.stack(fab_h_src)
            fab_mask = torch.stack(fab_mask)
            # |fab_input| = (current_batch_size, 1)
            # |fab_hidden| = (n_layers, current_batch_size, hidden_size)
            # |fab_cell| = (n_layers, current_batch_size, hidden_size)
            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
            # |fab_h_src| = (current_batch_size, length, hidden_size)
            # |fab_mask| = (current_batch_size, length)

            emb_t = self.emb_dec(fab_input)
            # |emb_t| = (current_batch_size, 1, word_vec_dim)

            fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(
                emb_t, fab_h_t_tilde, (fab_hidden, fab_cell))
            # |fab_decoder_output| = (current_batch_size, 1, hidden_size)
            context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
            # |context_vector| = (current_batch_size, 1, hidden_size)
            fab_h_t_tilde = self.tanh(
                self.concat(
                    torch.cat([fab_decoder_output, context_vector], dim=-1)))
            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
            y_hat = self.generator(fab_h_t_tilde)
            # |y_hat| = (current_batch_size, 1, output_size)

            # Separate the results for each sample.
            # fab_hidden[:, from_index:to_index, :] = (n_layers, beam_size, hidden_size)
            # fab_cell[:, from_index:to_index, :] = (n_layers, beam_size, hidden_size)
            # fab_h_t_tilde[from_index:to_index] = (beam_size, 1, hidden_size)
            cnt = 0
            for space in spaces:
                if space.is_done() == 0:
                    # Decide a range of each sample.
                    from_index = cnt * beam_size
                    to_index = from_index + beam_size

                    # pick k-best results for each sample.
                    space.collect_result(
                        y_hat[from_index:to_index],
                        [
                            ('hidden_state',
                             fab_hidden[:, from_index:to_index, :]),
                            ('cell_state', fab_cell[:,
                                                    from_index:to_index, :]),
                            ('h_t_1_tilde',
                             fab_h_t_tilde[from_index:to_index]),
                        ],
                    )
                    cnt += 1

            done_cnt = [space.is_done() for space in spaces]
            length += 1

        # Pick the n-best hypotheses.
        batch_sentences = []
        batch_probs = []

        # Collect the results.
        for i, space in enumerate(spaces):
            sentences, probs = space.get_n_best(n_best,
                                                length_penalty=length_penalty)

            batch_sentences += [sentences]
            batch_probs += [probs]

        return batch_sentences, batch_probs
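
Both the Transformer and the sequence-to-sequence versions rebuild the fabricated mini-batch by concatenating per-sample states along whichever dimension holds the batch: the registration tuples in Example #2 mark 'hidden_state' and 'cell_state' with batch dimension 1 and 'h_t_1_tilde' with batch dimension 0, which is why the loop above concatenates the former with dim=1 and the latter with dim=0. A minimal standalone sketch of that bookkeeping (all sizes below are made up for illustration):

    import torch

    n_layers, beam_size, hidden_size = 4, 5, 32

    # Per-sample decoder states, shaped as in the comments above.
    hidden_ = torch.zeros(n_layers, beam_size, hidden_size)  # batch dim is 1
    h_t_tilde_ = torch.zeros(beam_size, 1, hidden_size)      # batch dim is 0

    # Fabricate a mini-batch from two unfinished samples.
    fab_hidden = torch.cat([hidden_, hidden_], dim=1)
    fab_h_t_tilde = torch.cat([h_t_tilde_, h_t_tilde_], dim=0)

    print(fab_hidden.shape)     # torch.Size([4, 10, 32])
    print(fab_h_t_tilde.shape)  # torch.Size([10, 1, 32])
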
Example #3
    def batch_beam_search(self, src, beam_size=5, max_length=255, n_best=1):
        mask = None
        x_length = None
        if isinstance(src, tuple):
            x, x_length = src
            mask = self.generate_mask(x, x_length)
            # |mask| = (batch_size, length)
        else:
            x = src
        batch_size = x.size(0)

        emb_src = self.emb_src(x)
        h_src, h_0_tgt = self.encoder((emb_src, x_length))
        # |h_src| = (batch_size, length, hidden_size)
        h_0_tgt, c_0_tgt = h_0_tgt
        h_0_tgt = h_0_tgt.transpose(0, 1).contiguous().view(
            batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
        c_0_tgt = c_0_tgt.transpose(0, 1).contiguous().view(
            batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()
        # |h_0_tgt| = (n_layers, batch_size, hidden_size)
        h_0_tgt = (h_0_tgt, c_0_tgt)

        # Initialize one beam-search space per sample.
        spaces = [
            SingleBeamSearchSpace((h_0_tgt[0][:, i, :].unsqueeze(1),
                                   h_0_tgt[1][:, i, :].unsqueeze(1)),
                                  None,
                                  beam_size,
                                  max_length=max_length)
            for i in range(batch_size)
        ]
        done_cnt = [space.is_done() for space in spaces]

        length = 0
        while sum(done_cnt) < batch_size and length <= max_length:
            # current_batch_size = (batch_size - sum(done_cnt)) * beam_size

            # initialize fabricated variables.
            fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
            fab_h_src, fab_mask = [], []

            # batchify.
            for i, space in enumerate(spaces):
                if space.is_done() == 0:
                    y_hat_, (hidden_, cell_), h_t_tilde_ = space.get_batch()

                    fab_input += [y_hat_]
                    fab_hidden += [hidden_]
                    fab_cell += [cell_]
                    if h_t_tilde_ is not None:
                        fab_h_t_tilde += [h_t_tilde_]
                    else:
                        fab_h_t_tilde = None

                    fab_h_src += [h_src[i, :, :]] * beam_size
                    fab_mask += [mask[i, :]] * beam_size

            fab_input = torch.cat(fab_input, dim=0)
            fab_hidden = torch.cat(fab_hidden, dim=1)
            fab_cell = torch.cat(fab_cell, dim=1)
            if fab_h_t_tilde is not None:
                fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
            fab_h_src = torch.stack(fab_h_src)
            fab_mask = torch.stack(fab_mask)
            # |fab_input| = (current_batch_size, 1)
            # |fab_hidden| = (n_layers, current_batch_size, hidden_size)
            # |fab_cell| = (n_layers, current_batch_size, hidden_size)
            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
            # |fab_h_src| = (current_batch_size, length, hidden_size)
            # |fab_mask| = (current_batch_size, length)

            emb_t = self.emb_dec(fab_input)
            # |emb_t| = (current_batch_size, 1, word_vec_dim)

            fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(
                emb_t, fab_h_t_tilde, (fab_hidden, fab_cell))
            # |fab_decoder_output| = (current_batch_size, 1, hidden_size)
            context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
            # |context_vector| = (current_batch_size, 1, hidden_size)
            fab_h_t_tilde = self.tanh(
                self.concat(
                    torch.cat([fab_decoder_output, context_vector], dim=-1)))
            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
            y_hat = self.generator(fab_h_t_tilde)
            # |y_hat| = (current_batch_size, 1, output_size)

            # Separate the results for each sample.
            cnt = 0
            for space in spaces:
                if space.is_done() == 0:
                    from_index = cnt * beam_size
                    to_index = (cnt + 1) * beam_size

                    # pick k-best results for each sample.
                    space.collect_result(
                        y_hat[from_index:to_index],
                        (fab_hidden[:, from_index:to_index, :],
                         fab_cell[:, from_index:to_index, :]),
                        fab_h_t_tilde[from_index:to_index])
                    cnt += 1

            done_cnt = [space.is_done() for space in spaces]
            length += 1

        # Pick the n-best hypotheses for each sample.
        batch_sentences = []
        batch_probs = []

        for i, space in enumerate(spaces):
            sentences, probs = space.get_n_best(n_best)

            batch_sentences += [sentences]
            batch_probs += [probs]

        return batch_sentences, batch_probs
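
All three versions stop at index sequences and probabilities; turning the hypotheses into text requires a vocabulary lookup that is outside the scope of these examples. Below is a hedged post-processing sketch, assuming a hypothetical index_to_word list and that each hypothesis is a sequence of token indices:

    def indices_to_text(batch_sentences, index_to_word):
        # batch_sentences[i] is assumed to hold the n_best hypotheses for the
        # i-th sample, each a sequence of token indices; 'index_to_word' is a
        # hypothetical list mapping indices back to surface tokens.
        texts = []
        for hypotheses in batch_sentences:
            texts += [[' '.join(index_to_word[int(idx)] for idx in hyp)
                       for hyp in hypotheses]]
        return texts
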