Example 1
    def train_greedy_decode(self, inpute, mask=None, max_len=512):

        ence = self.enc(inpute, mask)

        bsize = inpute.size(0)

        # out: input to the decoder for the first step (bsize, 1), filled
        # with 3, the start-of-sequence token index in this codebase

        out = inpute.new_full((bsize, 1), 3)

        done_trans = None

        for i in range(0, max_len):

            _out = self.dec(ence, out, mask)

            _out = _out.argmax(dim=-1)

            wds = _out.narrow(1, _out.size(1) - 1, 1)

            out = torch.cat((out, wds), -1)

            # done_trans: (bsize)
            done_trans = wds.squeeze(1).eq(2) if done_trans is None else (
                done_trans | wds.squeeze(1).eq(2))

            if all_done(done_trans, bsize):
                break

        return out.narrow(1, 1, out.size(1) - 1)
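
All of these decoders stop early through an all_done helper and treat token id 2 as the end-of-sentence index. The helper is defined elsewhere in the codebase; a minimal sketch consistent with how it is called here (the signature is an assumption):

def all_done(done_trans, total):
    # done_trans: boolean tensor marking finished sequences; its element
    # count equals `total`, e.g. (bsize,) for greedy decoding or
    # (bsize, beam_size) with total = bsize * beam_size for beam search.
    return done_trans.int().sum().item() == total
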
Example 2
	def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, max_len=512, fill_pad=False, sample=False):

		bsize = inpute.size(0)

		sos_emb = self.get_sos_emb(inpute)

		sqrt_isize = sqrt(sos_emb.size(-1))

		out = sos_emb * sqrt_isize
		if self.pemb is not None:
			out = out + self.pemb.get_pos(0)

		if self.drop is not None:
			out = self.drop(out)

		states = {}

		for _tmp, net in enumerate(self.nets):
			out, _state = net(inpute, None, inputc, src_pad_mask, None, context_mask, out)
			states[_tmp] = _state

		if self.out_normer is not None:
			out = self.out_normer(out)

		out = self.classifier(out)
		wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

		trans = [wds]

		done_trans = wds.eq(2)

		for i in range(1, max_len):

			out = self.wemb(wds) * sqrt_isize
			if self.pemb is not None:
				out = out + self.pemb.get_pos(i)

			if self.drop is not None:
				out = self.drop(out)

			for _tmp, net in enumerate(self.nets):
				out, _state = net(inpute, states[_tmp], inputc, src_pad_mask, None, context_mask, out)
				states[_tmp] = _state

			if self.out_normer is not None:
				out = self.out_normer(out)

			out = self.classifier(out)
			wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

			trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds)

			done_trans = done_trans | wds.eq(2)
			if all_done(done_trans, bsize):
				break

		return torch.cat(trans, 1)
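
When sample=True, these decoders draw the next token from the softmax distribution rather than taking the argmax, through a SampleMax helper defined elsewhere. A minimal sketch consistent with the calls above (it assumes dim == -1, which is how it is always invoked here):

import torch

def SampleMax(probs, dim=-1, keepdim=False):
    # Draw one index per distribution along the last dimension,
    # mirroring the output shape of argmax(dim=-1).
    _size = probs.size()
    out = torch.multinomial(probs.reshape(-1, _size[-1]), 1).view(_size[:-1] + (1,))
    return out if keepdim else out.squeeze(-1)
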
Example 3
	def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=False, fill_pad=False):

		bsize, seql = inpute.size()[:2]

		beam_size2 = beam_size * beam_size
		bsizeb2 = bsize * beam_size2
		real_bsize = bsize * beam_size

		sos_emb = self.get_sos_emb(inpute)
		isize = sos_emb.size(-1)
		sqrt_isize = sqrt(isize)

		if length_penalty > 0.0:
			# lpv: length penalty vector for each beam (bsize * beam_size, 1)
			lpv = sos_emb.new_ones(real_bsize, 1)
			lpv_base = 6.0 ** length_penalty

		out = sos_emb * sqrt_isize
		if self.pemb is not None:
			out = out + self.pemb.get_pos(0)

		if self.drop is not None:
			out = self.drop(out)

		states = {}

		for _tmp, net in enumerate(self.nets):
			out, _state = net(inpute, None, src_pad_mask, None, out)
			states[_tmp] = _state

		if self.out_normer is not None:
			out = self.out_normer(out)

		# out: (bsize, 1, nwd)

		out = self.lsm(self.classifier(out))

		# scores: (bsize, 1, beam_size) => (bsize, beam_size)
		# wds: (bsize * beam_size, 1)
		# trans: (bsize * beam_size, 1)

		scores, wds = out.topk(beam_size, dim=-1)
		scores = scores.squeeze(1)
		sum_scores = scores
		wds = wds.view(real_bsize, 1)
		trans = wds

		# done_trans: (bsize, beam_size)

		done_trans = wds.view(bsize, beam_size).eq(2)

		# inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize)

		inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize)

		# _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql)

		_src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql)

		# states[i]: (bsize, 1, isize) => (bsize * beam_size, 1, isize)

		for key, value in states.items():
			states[key] = repeat_bsize_for_beam_tensor(value, beam_size)

		for step in range(1, max_len):

			out = self.wemb(wds) * sqrt_isize
			if self.pemb is not None:
				out = out + self.pemb.get_pos(step)

			if self.drop is not None:
				out = self.drop(out)

			for _tmp, net in enumerate(self.nets):
				out, _state = net(inpute, states[_tmp], _src_pad_mask, None, out)
				states[_tmp] = _state

			if self.out_normer is not None:
				out = self.out_normer(out)

			# out: (bsize, beam_size, nwd)

			out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1)

			# find the top k ** 2 candidates and calculate route scores for them
			# _scores: (bsize, beam_size, beam_size)
			# done_trans: (bsize, beam_size)
			# scores: (bsize, beam_size)
			# _wds: (bsize, beam_size, beam_size)
			# mask_from_done_trans: (bsize, beam_size) => (bsize, beam_size * beam_size)
			# added_scores: (bsize, 1, beam_size) => (bsize, beam_size, beam_size)

			_scores, _wds = out.topk(beam_size, dim=-1)
			_scores = (_scores.masked_fill(done_trans.unsqueeze(2).expand(bsize, beam_size, beam_size), 0.0) + sum_scores.unsqueeze(2).expand(bsize, beam_size, beam_size))

			if length_penalty > 0.0:
				lpv = lpv.masked_fill(~done_trans.view(real_bsize, 1), ((step + 6.0) ** length_penalty) / lpv_base)

			# prune the k ** 2 candidates, keeping the top-k for each sentence
			# scores: (bsize, beam_size * beam_size) => (bsize, beam_size)
			# _inds: indexes for the top-k candidate (bsize, beam_size)

			if clip_beam and (length_penalty > 0.0):
				scores, _inds = (_scores.view(real_bsize, beam_size) / lpv.expand(real_bsize, beam_size)).view(bsize, beam_size2).topk(beam_size, dim=-1)
				_tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
				sum_scores = _scores.view(bsizeb2).index_select(0, _tinds).view(bsize, beam_size)
			else:
				scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size, dim=-1)
				_tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
				sum_scores = scores

			# select the top-k candidates with the highest route scores and update the translation record
			# wds: (bsize, beam_size, beam_size) => (bsize * beam_size, 1)

			wds = _wds.view(bsizeb2).index_select(0, _tinds).view(real_bsize, 1)

			# reduce indices in _inds from the (beam_size ** 2) candidate space to beam_size,
			# recovering the source beam that each top-k candidate extends
			# _inds: indexes for the top-k candidate (bsize, beam_size)

			_inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)

			# gather the translation history of each top-k candidate and append the new token
			# trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)

			trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), 0) if fill_pad else wds), 1)

			done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size)

			# check early stop for beam search
			# done_trans: (bsize, beam_size)
			# scores: (bsize, beam_size)

			_done = False
			if length_penalty > 0.0:
				lpv = lpv.index_select(0, _inds)
			elif (not return_all) and all_done(done_trans.select(1, 0), bsize):
				_done = True

			# check whether all beams are done

			if _done or all_done(done_trans, real_bsize):
				break

			# update the corresponding hidden states
			# states[i]: (bsize * beam_size, nquery, isize)
			# _inds: (bsize, beam_size) => (bsize * beam_size)

			for key, value in states.items():
				states[key] = value.index_select(0, _inds)

		# when clip_beam is disabled, the length penalty is applied only here, after decoding
		if (not clip_beam) and (length_penalty > 0.0):
			scores = scores / lpv.view(bsize, beam_size)
			scores, _inds = scores.topk(beam_size, dim=-1)
			_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
			trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

		if return_all:

			return trans, scores
		else:

			return trans.view(bsize, beam_size, -1).select(1, 0)
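
The arange offsets in the beam bookkeeping above are dense; a toy illustration (made-up values, independent of any model) of the two index conversions:

import torch

bsize, beam_size = 2, 3
beam_size2 = beam_size * beam_size   # 9 candidates per sentence
bsizeb2 = bsize * beam_size2         # 18 candidates overall
real_bsize = bsize * beam_size

# _inds: per-sentence positions of the top-k candidates, (bsize, beam_size)
_inds = torch.tensor([[0, 4, 8], [1, 2, 7]])

# Offsetting row i by i * beam_size2 turns per-sentence positions into flat
# indices over all bsize * beam_size ** 2 candidates.
_tinds = (_inds + torch.arange(0, bsizeb2, beam_size2).unsqueeze(1).expand_as(_inds)).view(real_bsize)
print(_tinds)    # tensor([ 0,  4,  8, 10, 11, 16])

# _inds // beam_size recovers which of the beam_size source beams each
# candidate extends; the second offset makes that a flat beam index.
origins = (_inds // beam_size + torch.arange(0, real_bsize, beam_size).unsqueeze(1).expand_as(_inds)).view(real_bsize)
print(origins)   # tensor([0, 1, 2, 3, 3, 5])
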
Example 4
    def train_beam_decode(self,
                          inpute,
                          mask=None,
                          beam_size=8,
                          max_len=512,
                          length_penalty=0.0,
                          return_all=False,
                          clip_beam=clip_beam_with_lp):

        bsize, seql = inpute.size()

        real_bsize = bsize * beam_size

        ence = [
            encu.repeat(1, beam_size, 1).view(real_bsize, seql, -1)
            for encu in self.enc(inpute, mask)
        ]

        mask = mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql)

        # out: input to the decoder for the first step (bsize * beam_size, 1)

        out = inpute.new_full((real_bsize, 1), 3)

        if length_penalty > 0.0:
            # lpv: length penalty vector for each beam (bsize * beam_size, 1)
            lpv = ence[0].new_ones(real_bsize, 1)
            lpv_base = 6.0**length_penalty

        done_trans = None
        scores = None
        sum_scores = None

        beam_size2 = beam_size * beam_size
        bsizeb2 = bsize * beam_size2

        for step in range(1, max_len + 1):

            _out = self.dec(ence, out, mask)

            # _out: (bsize * beam_size, nquery, vocab_size) => (bsize, beam_size, vocab_size)
            _out = _out.narrow(1,
                               _out.size(1) - 1, 1).view(bsize, beam_size, -1)

            # _scores/_wds: (bsize, beam_size, beam_size)
            _scores, _wds = _out.topk(beam_size, dim=-1)

            if done_trans is not None:
                _done_trans_unsqueeze = done_trans.unsqueeze(2)
                # finished beams contribute no new token score; their
                # accumulated score is kept only in candidate 0 (the other
                # candidates are masked to -inf), so each finished
                # hypothesis survives exactly once in the top-k
                _scores = _scores.masked_fill(
                    _done_trans_unsqueeze.expand(bsize, beam_size, beam_size),
                    0.0) + sum_scores.unsqueeze(2).repeat(
                        1, 1, beam_size).masked_fill_(
                            select_zero_(
                                _done_trans_unsqueeze.repeat(1, 1, beam_size),
                                -1, 0), -inf_default)

                if length_penalty > 0.0:
                    lpv = lpv.masked_fill(~done_trans.view(real_bsize, 1),
                                          ((step + 5.0)**length_penalty) /
                                          lpv_base)

            # scores/_inds: (bsize, beam_size)
            if clip_beam and (length_penalty > 0.0):
                scores, _inds = (_scores /
                                 lpv.expand(real_bsize, beam_size)).view(
                                     bsize, beam_size2).topk(beam_size, dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)

                # sum_scores: (bsize, beam_size)
                sum_scores = _scores.view(bsizeb2).index_select(
                    0, _tinds).view(bsize, beam_size)

            else:
                scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size,
                                                                     dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = scores

            # wds: (bsize * beam_size, 1)
            wds = _wds.view(bsizeb2).index_select(0,
                                                  _tinds).view(real_bsize, 1)

            _inds = (
                _inds // beam_size +
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)
            out = torch.cat((out.index_select(0, _inds), wds), -1)

            # done_trans: (bsize, beam_size)
            done_trans = wds.view(
                bsize, beam_size).eq(2) if done_trans is None else (
                    done_trans.view(real_bsize).index_select(0, _inds)
                    | wds.view(real_bsize).eq(2)).view(bsize, beam_size)

            # check early stop for beam search
            # done_trans: (bsize, beam_size)
            # scores: (bsize, beam_size)

            _done = False
            if length_penalty > 0.0:
                lpv = lpv.index_select(0, _inds)
            elif (not return_all) and all_done(done_trans.select(1, 0), bsize):
                _done = True

            # check whether all beams are done

            if _done or all_done(done_trans, real_bsize):
                break

        out = out.narrow(1, 1, out.size(1) - 1)

        # when clip_beam is disabled, the length penalty is applied only here, after decoding
        if (not clip_beam) and (length_penalty > 0.0):
            scores = scores / lpv.view(bsize, beam_size)

        if return_all:

            return out.view(bsize, beam_size, -1), scores
        else:

            # out: (bsize * beam_size, nquery) => (bsize, nquery)
            return out.view(bsize, beam_size, -1).select(1, 0)
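
The lpv updates in these beam decoders implement the GNMT-style length penalty (Wu et al., 2016): a beam's accumulated log-probability is divided by ((5 + |Y|) / 6) ** length_penalty, with the denominator precomputed as lpv_base. A standalone sketch of the quantity being maintained incrementally:

def gnmt_length_penalty(length, alpha):
    # ((5 + |Y|) / 6) ** alpha; the (step + 5.0) ** alpha / lpv_base update
    # above is this value with length == step, while the (step + 6.0)
    # variants account for the token generated before the loop starts.
    return ((5.0 + length) ** alpha) / (6.0 ** alpha)
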
Example 5
    def beam_decode(self,
                    inpute,
                    inputh,
                    src_pad_mask=None,
                    chk_pad_mask=None,
                    beam_size=8,
                    max_len=512,
                    length_penalty=0.0,
                    return_all=False,
                    clip_beam=clip_beam_with_lp,
                    fill_pad=False):

        bsize, seql = inpute.size()[:2]

        beam_size2 = beam_size * beam_size
        bsizeb2 = bsize * beam_size2
        real_bsize = bsize * beam_size

        sos_emb = self.get_sos_emb(inpute)
        isize = sos_emb.size(-1)
        sqrt_isize = sqrt(isize)

        if length_penalty > 0.0:
            lpv = sos_emb.new_ones(real_bsize, 1)
            lpv_base = 6.0**length_penalty

        out = sos_emb * sqrt_isize
        if self.pemb is not None:
            out = out + self.pemb.get_pos(0)

        if self.drop is not None:
            out = self.drop(out)

        out = self.out_normer(out)

        states = {}

        for _tmp, (net, inputu, inputhu) in enumerate(
                zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))):
            out, _state = net(inputu, inputhu, None, src_pad_mask,
                              chk_pad_mask, None, out, True)
            states[_tmp] = _state

        out = self.lsm(self.classifier(out))

        scores, wds = out.topk(beam_size, dim=-1)
        scores = scores.squeeze(1)
        sum_scores = scores
        wds = wds.view(real_bsize, 1)
        trans = wds

        done_trans = wds.view(bsize, beam_size).eq(2)

        #inputh = repeat_bsize_for_beam_tensor(inputh, beam_size)
        self.repeat_cross_attn_buffer(beam_size)

        _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(
            1, beam_size, 1).view(real_bsize, 1, seql)
        _chk_pad_mask = None if chk_pad_mask is None else repeat_bsize_for_beam_tensor(
            chk_pad_mask, beam_size)

        states = expand_bsize_for_beam(states, beam_size=beam_size)

        for step in range(1, max_len):

            out = self.wemb(wds) * sqrt_isize
            if self.pemb is not None:
                out = out + self.pemb.get_pos(step)

            if self.drop is not None:
                out = self.drop(out)

            out = self.out_normer(out)

            for _tmp, (net, inputu, inputhu) in enumerate(
                    zip(self.nets, inpute.unbind(dim=-1),
                        inputh.unbind(dim=-1))):
                out, _state = net(inputu, inputhu, states[_tmp], _src_pad_mask,
                                  _chk_pad_mask, None, out, True)
                states[_tmp] = _state

            out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1)

            _scores, _wds = out.topk(beam_size, dim=-1)
            _done_trans_unsqueeze = done_trans.unsqueeze(2)
            _scores = (
                _scores.masked_fill(
                    _done_trans_unsqueeze.expand(bsize, beam_size, beam_size),
                    0.0) +
                sum_scores.unsqueeze(2).repeat(1, 1, beam_size).masked_fill_(
                    select_zero_(_done_trans_unsqueeze.repeat(1, 1, beam_size),
                                 -1, 0), -inf_default))

            if length_penalty > 0.0:
                lpv.masked_fill_(~done_trans.view(real_bsize, 1),
                                 ((step + 6.0)**length_penalty) / lpv_base)

            if clip_beam and (length_penalty > 0.0):
                scores, _inds = (_scores.view(real_bsize, beam_size) /
                                 lpv.expand(real_bsize, beam_size)).view(
                                     bsize, beam_size2).topk(beam_size, dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = _scores.view(bsizeb2).index_select(
                    0, _tinds).view(bsize, beam_size)
            else:
                scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size,
                                                                     dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = scores

            wds = _wds.view(bsizeb2).index_select(0,
                                                  _tinds).view(real_bsize, 1)

            _inds = (
                _inds // beam_size +
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)

            trans = torch.cat(
                (trans.index_select(0, _inds),
                 wds.masked_fill(done_trans.view(real_bsize, 1), pad_id)
                 if fill_pad else wds), 1)

            done_trans = (done_trans.view(real_bsize).index_select(0, _inds)
                          | wds.eq(2).squeeze(1)).view(bsize, beam_size)

            _done = False
            if length_penalty > 0.0:
                lpv = lpv.index_select(0, _inds)
            elif (not return_all) and all_done(done_trans.select(1, 0), bsize):
                _done = True

            if _done or all_done(done_trans, real_bsize):
                break

            states = index_tensors(states, indices=_inds, dim=0)

        if (not clip_beam) and (length_penalty > 0.0):
            scores = scores / lpv.view(bsize, beam_size)
            scores, _inds = scores.topk(beam_size, dim=-1)
            _inds = (
                _inds +
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)
            trans = trans.view(real_bsize, -1).index_select(0, _inds)

        if return_all:

            return trans.view(bsize, beam_size, -1), scores
        else:

            return trans.view(bsize, beam_size, -1).select(1, 0)
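
This variant manages nested per-layer state dicts through expand_bsize_for_beam and index_tensors, alongside the repeat_bsize_for_beam_tensor helper the earlier examples use. All three are defined elsewhere in the codebase; minimal sketches consistent with how they are called (the recursive handling of dicts is an assumption):

import torch

def repeat_bsize_for_beam_tensor(tensor, beam_size):
    # (bsize, d1, ...) -> (bsize * beam_size, d1, ...), keeping the beam_size
    # copies of each batch entry adjacent, matching the inline
    # repeat(...).view(...) pattern applied to inpute and the pad masks.
    # Assumes at least two dimensions.
    _size = list(tensor.size())
    return tensor.repeat(1, beam_size, *([1] * (len(_size) - 2))).view(_size[0] * beam_size, *_size[1:])

def expand_bsize_for_beam(states, beam_size=None):
    # Recursively tile every tensor in a (possibly nested) dict of decoder states.
    if isinstance(states, dict):
        return {k: expand_bsize_for_beam(v, beam_size=beam_size) for k, v in states.items()}
    return repeat_bsize_for_beam_tensor(states, beam_size)

def index_tensors(states, indices=None, dim=0):
    # Recursively reorder (possibly nested) beam states after each search step.
    if isinstance(states, dict):
        return {k: index_tensors(v, indices=indices, dim=dim) for k, v in states.items()}
    return states.index_select(dim, indices)
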
Example 6
	def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False):

		bsize = inpute.size(0)

		out = self.get_sos_emb(inpute)

		# out: input to the decoder for the first step (bsize, 1, isize)

		if self.drop is not None:
			out = self.drop(out)

		out, statefl = self.flayer(out, "init", True)

		states = {}

		if self.projector is not None:
			inpute = self.projector(inpute)

		attn = self.attn(out.unsqueeze(1), inpute, src_pad_mask).squeeze(1)

		for _tmp, net in enumerate(self.nets):
			out, _state = net(out, attn, "init", True)
			states[_tmp] = _state

		if self.out_normer is not None:
			out = self.out_normer(out)

		# out: (bsize, nwd)
		out = self.classifier(torch.cat((out, attn), -1))
		# wds: (bsize)
		wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

		trans = [wds]

		# done_trans: (bsize)

		done_trans = wds.eq(2)

		for i in range(1, max_len):

			out = self.wemb(wds)

			if self.drop is not None:
				out = self.drop(out)

			out, statefl = self.flayer(out, statefl)

			attn = self.attn(out.unsqueeze(1), inpute, src_pad_mask).squeeze(1)

			for _tmp, net in enumerate(self.nets):
				out, _state = net(out, attn, states[_tmp])
				states[_tmp] = _state

			if self.out_normer is not None:
				out = self.out_normer(out)

			out = self.classifier(torch.cat((out, attn), -1))
			wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

			trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

			done_trans = done_trans | wds.eq(2)
			if all_done(done_trans, bsize):
				break

		return torch.stack(trans, 1)
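
Note that wds stays one-dimensional here, (bsize,), because the classifier output is (bsize, nwd) rather than (bsize, 1, nwd); torch.stack(trans, 1) therefore assembles the (bsize, length) result, whereas the Transformer decoders above keep wds as (bsize, 1) and concatenate with torch.cat.
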
Example 7
	def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False):

		bsize, seql = inpute.size()[:2]

		beam_size2 = beam_size * beam_size
		bsizeb2 = bsize * beam_size2
		real_bsize = bsize * beam_size

		out = self.get_sos_emb(inpute)
		isize = out.size(-1)

		if length_penalty > 0.0:
			# lpv: length penalty vector for each beam (bsize * beam_size, 1)
			lpv = out.new_ones(real_bsize, 1)
			lpv_base = 6.0 ** length_penalty

		if self.drop is not None:
			out = self.drop(out)

		out, statefl = self.flayer(out, "init", True)
		statefl = torch.stack(statefl, -2)

		states = {}

		if self.projector is not None:
			inpute = self.projector(inpute)

		attn = self.attn(out.unsqueeze(1), inpute, src_pad_mask).squeeze(1)

		for _tmp, net in enumerate(self.nets):
			out, _state = net(out, attn, "init", True)
			states[_tmp] = torch.stack(_state, -2)

		if self.out_normer is not None:
			out = self.out_normer(out)

		# out: (bsize, nwd)

		out = self.lsm(self.classifier(torch.cat((out, attn), -1)))

		# scores: (bsize, beam_size)
		# wds: (bsize * beam_size)
		# trans: (bsize * beam_size, 1)

		scores, wds = out.topk(beam_size, dim=-1)
		sum_scores = scores
		wds = wds.view(real_bsize)
		trans = wds.unsqueeze(1)

		# done_trans: (bsize, beam_size)

		done_trans = wds.view(bsize, beam_size).eq(2)

		# inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize)

		self.repeat_cross_attn_buffer(beam_size)

		# _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql)

		_src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql)

		# states[i]: (bsize, 2, isize) => (bsize * beam_size, 2, isize)

		statefl = statefl.repeat(1, beam_size, 1).view(real_bsize, 2, isize)
		states = expand_bsize_for_beam(states, beam_size=beam_size)

		for step in range(1, max_len):

			out = self.wemb(wds)

			if self.drop is not None:
				out = self.drop(out)

			out, statefl = self.flayer(out, statefl.unbind(-2))
			statefl = torch.stack(statefl, -2)

			attn = self.attn(out.unsqueeze(1), inpute, _src_pad_mask).squeeze(1)

			for _tmp, net in enumerate(self.nets):
				out, _state = net(out, attn, states[_tmp].unbind(-2))
				states[_tmp] = torch.stack(_state, -2)

			if self.out_normer is not None:
				out = self.out_normer(out)

			# out: (bsize, beam_size, nwd)

			out = self.lsm(self.classifier(torch.cat((out, attn), -1))).view(bsize, beam_size, -1)

			# find the top k ** 2 candidates and calculate route scores for them
			# _scores: (bsize, beam_size, beam_size)
			# done_trans: (bsize, beam_size)
			# scores: (bsize, beam_size)
			# _wds: (bsize, beam_size, beam_size)
			# mask_from_done_trans: (bsize, beam_size) => (bsize, beam_size * beam_size)
			# added_scores: (bsize, 1, beam_size) => (bsize, beam_size, beam_size)

			_scores, _wds = out.topk(beam_size, dim=-1)
			_done_trans_unsqueeze = done_trans.unsqueeze(2)
			_scores = (_scores.masked_fill(_done_trans_unsqueeze.expand(bsize, beam_size, beam_size), 0.0) + sum_scores.unsqueeze(2).repeat(1, 1, beam_size).masked_fill_(select_zero_(_done_trans_unsqueeze.repeat(1, 1, beam_size), -1, 0), -inf_default))

			if length_penalty > 0.0:
				lpv.masked_fill_(~done_trans.view(real_bsize, 1), ((step + 6.0) ** length_penalty) / lpv_base)

			# prune the k ** 2 candidates, keeping the top-k for each sentence
			# scores: (bsize, beam_size * beam_size) => (bsize, beam_size)
			# _inds: indexes for the top-k candidate (bsize, beam_size)

			if clip_beam and (length_penalty > 0.0):
				scores, _inds = (_scores.view(real_bsize, beam_size) / lpv.expand(real_bsize, beam_size)).view(bsize, beam_size2).topk(beam_size, dim=-1)
				_tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
				sum_scores = _scores.view(bsizeb2).index_select(0, _tinds).view(bsize, beam_size)
			else:
				scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size, dim=-1)
				_tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
				sum_scores = scores

			# select the top-k candidates with the highest route scores and update the translation record
			# wds: (bsize, beam_size, beam_size) => (bsize * beam_size)

			wds = _wds.view(bsizeb2).index_select(0, _tinds)

			# reduce indices in _inds from the (beam_size ** 2) candidate space to beam_size,
			# recovering the source beam that each top-k candidate extends
			# _inds: indexes for the top-k candidate (bsize, beam_size)

			_inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)

			# gather the translation history of each top-k candidate and append the new token
			# trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)

			trans = torch.cat((trans.index_select(0, _inds), (wds.masked_fill(done_trans.view(real_bsize), pad_id) if fill_pad else wds).unsqueeze(1)), 1)

			done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2)).view(bsize, beam_size)

			# check early stop for beam search
			# done_trans: (bsize, beam_size)
			# scores: (bsize, beam_size)

			_done = False
			if length_penalty > 0.0:
				lpv = lpv.index_select(0, _inds)
			elif (not return_all) and all_done(done_trans.select(1, 0), bsize):
				_done = True

			# check whether all beams are done

			if _done or all_done(done_trans, real_bsize):
				break

			# update the corresponding hidden states
			# states[i]: (bsize * beam_size, 2, isize)
			# _inds: (bsize, beam_size) => (bsize * beam_size)

			statefl = statefl.index_select(0, _inds)
			states = index_tensors(states, indices=_inds, dim=0)

		# when clip_beam is disabled, the length penalty is applied only here, after decoding
		if (not clip_beam) and (length_penalty > 0.0):
			scores = scores / lpv.view(bsize, beam_size)
			scores, _inds = scores.topk(beam_size, dim=-1)
			_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
			trans = trans.view(real_bsize, -1).index_select(0, _inds)

		if return_all:

			return trans.view(bsize, beam_size, -1), scores
		else:

			return trans.view(bsize, beam_size, -1).select(1, 0)
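
This decoder packs what appears to be a pair of recurrent states into a single tensor with torch.stack(_state, -2), so that one index_select along dim 0 reorders both parts whenever the beams are reshuffled, and unpacks them again with unbind(-2). A toy illustration (made-up shapes):

import torch

h, c = torch.zeros(4, 8), torch.ones(4, 8)   # e.g. hidden and cell states
packed = torch.stack((h, c), -2)             # (4, 2, 8): batch stays dim 0
reordered = packed.index_select(0, torch.tensor([1, 0, 3, 2]))
h2, c2 = reordered.unbind(-2)                # recover the pair for the next step
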
Example 8
    def greedy_decode(self,
                      inpute,
                      src_pad_mask=None,
                      max_len=512,
                      fill_pad=False,
                      sample=False):

        bsize, seql, isize = inpute[0].size()

        sqrt_isize = sqrt(isize)

        states = {}

        outs = []

        for _inum, (model, inputu) in enumerate(zip(self.nets, inpute)):

            sos_emb = model.get_sos_emb(inputu)

            # out: input to the decoder for the first step (bsize, 1, isize)

            out = sos_emb * sqrt_isize
            if model.pemb is not None:
                out = out + model.pemb.get_pos(0)

            if model.drop is not None:
                out = model.drop(out)

            # states[_inum]: per-model cache of decoder-layer states
            states[_inum] = {}

            for _tmp, net in enumerate(model.nets):
                out, _state = net(inputu, None, src_pad_mask, None, out)
                states[_inum][_tmp] = _state

            if model.out_normer is not None:
                out = model.out_normer(out)

            # outs: [(bsize, 1, nwd)...]

            outs.append(model.classifier(out).softmax(dim=-1))

        out = torch.stack(outs).mean(0)
        wds = SampleMax(out, dim=-1, keepdim=False) if sample else out.argmax(
            dim=-1)

        trans = [wds]

        # done_trans: (bsize, 1)

        done_trans = wds.eq(2)

        for i in range(1, max_len):

            outs = []

            for _inum, (model, inputu) in enumerate(zip(self.nets, inpute)):

                out = model.wemb(wds) * sqrt_isize
                if model.pemb is not None:
                    out = out + model.pemb.get_pos(i)

                if model.drop is not None:
                    out = model.drop(out)

                for _tmp, net in enumerate(model.nets):
                    out, _state = net(inputu, states[_inum][_tmp],
                                      src_pad_mask, None, out)
                    states[_inum][_tmp] = _state

                if model.out_normer is not None:
                    out = model.out_normer(out)

                # outs: [(bsize, 1, nwd)...]
                outs.append(model.classifier(out).softmax(dim=-1))

            out = torch.stack(outs).mean(0)
            wds = SampleMax(out, dim=-1,
                            keepdim=False) if sample else out.argmax(dim=-1)

            trans.append(
                wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

            done_trans = done_trans | wds.eq(2)
            if all_done(done_trans, bsize):
                break

        return torch.cat(trans, 1)
Example 9
    def greedy_decode(self,
                      inpute,
                      src_pad_mask=None,
                      max_len=512,
                      fill_pad=False,
                      sample=False):

        bsize = inpute.size(0)

        sos_emb = self.get_sos_emb(inpute)

        sqrt_isize = sqrt(sos_emb.size(-1))

        # out: input to the decoder for the first step (bsize, 1, isize)

        out = sos_emb * sqrt_isize
        if self.pemb is not None:
            out = out + self.pemb.get_pos(0)

        if self.drop is not None:
            out = self.drop(out)

        states = {}

        for _tmp, (net,
                   inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))):
            out, _state = net(inputu, None, src_pad_mask, None, out, True)
            states[_tmp] = _state

        if self.out_normer is not None:
            out = self.out_normer(out)

        out = self.classifier(out)
        wds = SampleMax(out.softmax(-1), dim=-1,
                        keepdim=False) if sample else out.argmax(dim=-1)

        trans = [wds]

        # done_trans: (bsize, 1)

        done_trans = wds.eq(2)

        for i in range(1, max_len):

            out = self.wemb(wds) * sqrt_isize
            if self.pemb is not None:
                out = out + self.pemb.get_pos(i)

            if self.drop is not None:
                out = self.drop(out)

            for _tmp, (net, inputu) in enumerate(
                    zip(self.nets, inpute.unbind(dim=-1))):
                out, _state = net(inputu, states[_tmp], src_pad_mask, None,
                                  out, True)
                states[_tmp] = _state

            if self.out_normer is not None:
                out = self.out_normer(out)

            # out: (bsize, 1, nwd)
            out = self.classifier(out)
            wds = SampleMax(out.softmax(-1), dim=-1,
                            keepdim=False) if sample else out.argmax(dim=-1)

            trans.append(
                wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

            done_trans = done_trans | wds.eq(2)
            if all_done(done_trans, bsize):
                break

        return torch.cat(trans, 1)
Example 10
    def beam_decode(self,
                    inpute,
                    inputm,
                    src_pad_mask=None,
                    mt_pad_mask=None,
                    beam_size=8,
                    max_len=512,
                    length_penalty=0.0,
                    return_all=False,
                    clip_beam=False,
                    fill_pad=False):

        bsize, seql = inpute.size()[:2]
        mtl = inputm.size(1)

        beam_size2 = beam_size * beam_size
        bsizeb2 = bsize * beam_size2
        real_bsize = bsize * beam_size

        sos_emb = self.get_sos_emb(inpute)
        isize = sos_emb.size(-1)
        sqrt_isize = sqrt(isize)

        if length_penalty > 0.0:
            lpv = sos_emb.new_ones(real_bsize, 1)
            lpv_base = 6.0**length_penalty

        out = sos_emb * sqrt_isize
        if self.pemb is not None:
            out = out + self.pemb.get_pos(0)

        if self.drop is not None:
            out = self.drop(out)

        states = {}

        for _tmp, net in enumerate(self.nets):
            out, _state = net(inpute, inputm, None, src_pad_mask, mt_pad_mask,
                              None, out)
            states[_tmp] = _state

        if self.out_normer is not None:
            out = self.out_normer(out)

        out = self.lsm(self.classifier(out))

        scores, wds = out.topk(beam_size, dim=-1)
        scores = scores.squeeze(1)
        sum_scores = scores
        wds = wds.view(real_bsize, 1)
        trans = wds

        done_trans = wds.view(bsize, beam_size).eq(2)

        inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize)
        inputm = inputm.repeat(1, beam_size, 1).view(real_bsize, mtl, isize)

        _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(
            1, beam_size, 1).view(real_bsize, 1, seql)
        _mt_pad_mask = None if mt_pad_mask is None else mt_pad_mask.repeat(
            1, beam_size, 1).view(real_bsize, 1, mtl)

        for key, value in states.items():
            states[key] = repeat_bsize_for_beam_tensor(value, beam_size)

        for step in range(1, max_len):

            out = self.wemb(wds) * sqrt_isize
            if self.pemb is not None:
                out = out + self.pemb.get_pos(step)

            if self.drop is not None:
                out = self.drop(out)

            for _tmp, net in enumerate(self.nets):
                out, _state = net(inpute, inputm, states[_tmp], _src_pad_mask,
                                  _mt_pad_mask, None, out)
                states[_tmp] = _state

            if self.out_normer is not None:
                out = self.out_normer(out)

            out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1)

            _scores, _wds = out.topk(beam_size, dim=-1)
            _scores = (
                _scores.masked_fill(
                    done_trans.unsqueeze(2).expand(bsize, beam_size,
                                                   beam_size), 0.0) +
                sum_scores.unsqueeze(2).expand(bsize, beam_size, beam_size))

            if length_penalty > 0.0:
                lpv = lpv.masked_fill(~done_trans.view(real_bsize, 1),
                                      ((step + 6.0)**length_penalty) /
                                      lpv_base)

            if clip_beam and (length_penalty > 0.0):
                scores, _inds = (_scores.view(real_bsize, beam_size) /
                                 lpv.expand(real_bsize, beam_size)).view(
                                     bsize, beam_size2).topk(beam_size, dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = _scores.view(bsizeb2).index_select(
                    0, _tinds).view(bsize, beam_size)
            else:
                scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size,
                                                                     dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = scores

            wds = _wds.view(bsizeb2).index_select(0,
                                                  _tinds).view(real_bsize, 1)

            _inds = (
                _inds // beam_size +
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)

            trans = torch.cat(
                (trans.index_select(0, _inds),
                 wds.masked_fill(done_trans.view(real_bsize, 1), pad_id)
                 if fill_pad else wds), 1)

            done_trans = (done_trans.view(real_bsize).index_select(0, _inds)
                          | wds.eq(2).squeeze(1)).view(bsize, beam_size)

            _done = False
            if length_penalty > 0.0:
                lpv = lpv.index_select(0, _inds)
            elif (not return_all) and all_done(done_trans.select(1, 0), bsize):
                _done = True

            if _done or all_done(done_trans, real_bsize):
                break

            for key, value in states.items():
                states[key] = value.index_select(0, _inds)

        if (not clip_beam) and (length_penalty > 0.0):
            scores = scores / lpv.view(bsize, beam_size)
            scores, _inds = scores.topk(beam_size, dim=-1)
            _inds = (
                _inds +
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)
            trans = trans.view(real_bsize, -1).index_select(0, _inds).view(
                bsize, beam_size, -1)

        if return_all:

            return trans, scores
        else:

            return trans.view(bsize, beam_size, -1).select(1, 0)
Example 11
    def greedy_decode(self,
                      inpute,
                      src_pad_mask=None,
                      max_len=512,
                      fill_pad=False,
                      sample=False):

        bsize = inpute.size(0)

        out = self.get_sos_emb(inpute)

        if self.drop is not None:
            out = self.drop(out)

        states = {}

        for _tmp, net in enumerate(self.nets):
            out, _state = net(inpute,
                              "init",
                              src_pad_mask=src_pad_mask,
                              query_unit=out)
            states[_tmp] = _state

        if self.out_normer is not None:
            out = self.out_normer(out)

        out = self.classifier(out)
        wds = SampleMax(out.softmax(-1), dim=-1,
                        keepdim=False) if sample else out.argmax(dim=-1)

        trans = [wds]

        done_trans = wds.eq(2)

        for i in range(1, max_len):

            out = self.wemb(wds)

            if self.drop is not None:
                out = self.drop(out)

            for _tmp, net in enumerate(self.nets):
                out, _state = net(inpute,
                                  states[_tmp],
                                  src_pad_mask=src_pad_mask,
                                  query_unit=out)
                states[_tmp] = _state

            if self.out_normer is not None:
                out = self.out_normer(out)

            out = self.classifier(out)
            wds = SampleMax(out.softmax(-1), dim=-1,
                            keepdim=False) if sample else out.argmax(dim=-1)

            trans.append(
                wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

            done_trans = done_trans | wds.eq(2)
            if all_done(done_trans, bsize):
                break

        return torch.cat(trans, 1)
Example 12
    def beam_decode(self,
                    inpute,
                    src_pad_mask=None,
                    beam_size=8,
                    max_len=512,
                    length_penalty=0.0,
                    return_all=False,
                    clip_beam=clip_beam_with_lp,
                    fill_pad=False):

        bsize, seql, isize = inpute[0].size()

        beam_size2 = beam_size * beam_size
        bsizeb2 = bsize * beam_size2
        real_bsize = bsize * beam_size

        sqrt_isize = sqrt(isize)

        if length_penalty > 0.0:
            # lpv: length penalty vector for each beam (bsize * beam_size, 1)
            lpv = inpute[0].new_ones(real_bsize, 1)
            lpv_base = 6.0**length_penalty

        states = {}

        outs = []

        for _inum, (model, inputu) in enumerate(zip(self.nets, inpute)):

            out = model.get_sos_emb(inputu) * sqrt_isize
            if model.pemb is not None:
                out = out + model.pemb.get_pos(0)

            if model.drop is not None:
                out = model.drop(out)

            states[_inum] = {}

            for _tmp, net in enumerate(model.nets):
                out, _state = net(inputu, None, src_pad_mask, out, 1)
                states[_inum][_tmp] = _state

            if model.out_normer is not None:
                out = model.out_normer(out)

            # outs: [(bsize, 1, nwd)]

            outs.append(model.classifier(out).softmax(dim=-1))

        out = torch.stack(outs).mean(0).log()

        # scores: (bsize, 1, beam_size) => (bsize, beam_size)
        # wds: (bsize * beam_size, 1)
        # trans: (bsize * beam_size, 1)

        scores, wds = out.topk(beam_size, dim=-1)
        scores = scores.squeeze(1)
        sum_scores = scores
        wds = wds.view(real_bsize, 1)
        trans = wds

        # done_trans: (bsize, beam_size)

        done_trans = wds.view(bsize, beam_size).eq(2)

        # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize)

        inpute = [
            inputu.repeat(1, beam_size, 1).view(real_bsize, seql, isize)
            for inputu in inpute
        ]

        # _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql)

        _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(
            1, beam_size, 1).view(real_bsize, 1, seql)

        # states[i][j]: (bsize, 1, isize) => (bsize * beam_size, 1, isize)

        states = expand_bsize_for_beam(states, beam_size=beam_size)

        for step in range(2, max_len + 1):

            outs = []

            for _inum, (model, inputu) in enumerate(zip(self.nets, inpute)):

                out = model.wemb(wds) * sqrt_isize
                if model.pemb is not None:
                    out = out + model.pemb.get_pos(step - 1)

                if model.drop is not None:
                    out = model.drop(out)

                for _tmp, net in enumerate(model.nets):
                    out, _state = net(inputu, states[_inum][_tmp],
                                      _src_pad_mask, out, step)
                    states[_inum][_tmp] = _state

                if model.out_normer is not None:
                    out = model.out_normer(out)

                # outs: [(bsize, beam_size, nwd)...]

                outs.append(
                    model.classifier(out).softmax(dim=-1).view(
                        bsize, beam_size, -1))

            out = torch.stack(outs).mean(0).log()

            # find the top k ** 2 candidates and calculate route scores for them
            # _scores: (bsize, beam_size, beam_size)
            # done_trans: (bsize, beam_size)
            # scores: (bsize, beam_size)
            # _wds: (bsize, beam_size, beam_size)
            # mask_from_done_trans: (bsize, beam_size) => (bsize, beam_size * beam_size)
            # added_scores: (bsize, 1, beam_size) => (bsize, beam_size, beam_size)

            _scores, _wds = out.topk(beam_size, dim=-1)
            _scores = (_scores.masked_fill(
                done_trans.unsqueeze(2).expand(bsize, beam_size, beam_size),
                0.0) + sum_scores.unsqueeze(2).expand(bsize, beam_size, beam_size))

            if length_penalty > 0.0:
                lpv = lpv.masked_fill(~done_trans.view(real_bsize, 1),
                                      ((step + 5.0)**length_penalty) /
                                      lpv_base)

            # prune the k ** 2 candidates, keeping the top-k for each sentence
            # scores: (bsize, beam_size * beam_size) => (bsize, beam_size)
            # _inds: indexes for the top-k candidate (bsize, beam_size)

            if clip_beam and (length_penalty > 0.0):
                scores, _inds = (_scores.view(real_bsize, beam_size) /
                                 lpv.expand(real_bsize, beam_size)).view(
                                     bsize, beam_size2).topk(beam_size, dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = _scores.view(bsizeb2).index_select(
                    0, _tinds).view(bsize, beam_size)
            else:
                scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size,
                                                                     dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = scores

            # select the top-k candidates with the highest route scores and update the translation record
            # wds: (bsize, beam_size, beam_size) => (bsize * beam_size, 1)

            wds = _wds.view(bsizeb2).index_select(0,
                                                  _tinds).view(real_bsize, 1)

            # reduce indices in _inds from the (beam_size ** 2) candidate space to beam_size,
            # recovering the source beam that each top-k candidate extends
            # _inds: indexes for the top-k candidate (bsize, beam_size)

            _inds = (
                _inds // beam_size +
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)

            # gather the translation history of each top-k candidate and append the new token
            # trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)

            trans = torch.cat(
                (trans.index_select(0, _inds),
                 wds.masked_fill(done_trans.view(real_bsize, 1), pad_id)
                 if fill_pad else wds), 1)

            done_trans = (done_trans.view(real_bsize).index_select(0, _inds)
                          | wds.eq(2).squeeze(1)).view(bsize, beam_size)

            # check early stop for beam search
            # done_trans: (bsize, beam_size)
            # scores: (bsize, beam_size)

            _done = False
            if length_penalty > 0.0:
                lpv = lpv.index_select(0, _inds)
            elif (not return_all) and all_done(done_trans.select(1, 0), bsize):
                _done = True

            # check whether all beams are done

            if _done or all_done(done_trans, real_bsize):
                break

            # update the corresponding hidden states
            # states[i][j]: (bsize * beam_size, nquery, isize)
            # _inds: (bsize, beam_size) => (bsize * beam_size)

            states = index_tensors(states, indices=_inds, dim=0)

        # when clip_beam is disabled, the length penalty is applied only here, after decoding
        if (not clip_beam) and (length_penalty > 0.0):
            scores = scores / lpv.view(bsize, beam_size)
            scores, _inds = scores.topk(beam_size, dim=-1)
            _inds = (
                _inds +
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)
            trans = trans.view(real_bsize, -1).index_select(0, _inds)

        if return_all:

            return trans.view(bsize, beam_size, -1), scores
        else:

            return trans.view(bsize, beam_size, -1).select(1, 0)
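
Note that the ensemble combines the models in probability space and only then takes the log, so the beam scores remain additive log-probabilities of the averaged distribution. That is not the same as averaging per-model log-probabilities; a toy contrast (made-up numbers):

import torch

p1 = torch.tensor([0.9, 0.1])
p2 = torch.tensor([0.2, 0.8])

log_avg = torch.stack((p1, p2)).mean(0).log()        # what beam_decode scores
avg_log = torch.stack((p1.log(), p2.log())).mean(0)  # a different quantity
print(log_avg)   # tensor([-0.5978, -0.7985])
print(avg_log)   # tensor([-0.8574, -1.2629])
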