Example #1
	def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, max_len=512, fill_pad=False, sample=False):

		bsize = inpute.size(0)

		# sos_emb: decoder input for the first step (bsize, 1, isize)
		sos_emb = self.get_sos_emb(inpute)

		sqrt_isize = sqrt(sos_emb.size(-1))

		out = sos_emb * sqrt_isize
		if self.pemb is not None:
			out = out + self.pemb.get_pos(0)

		if self.drop is not None:
			out = self.drop(out)

		states = {}

		for _tmp, net in enumerate(self.nets):
			out, _state = net(inpute, None, inputc, src_pad_mask, context_mask, None, out)
			states[_tmp] = _state

		if self.out_normer is not None:
			out = self.out_normer(out)

		out = self.classifier(out)
		wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

		trans = [wds]

		done_trans = wds.eq(2)

		for i in range(1, max_len):

			out = self.wemb(wds) * sqrt_isize
			if self.pemb is not None:
				out = out + self.pemb.get_pos(i)

			if self.drop is not None:
				out = self.drop(out)

			for _tmp, net in enumerate(self.nets):
				out, _state = net(inpute, states[_tmp], inputc, src_pad_mask, None, context_mask, out)
				states[_tmp] = _state

			if self.out_normer is not None:
				out = self.out_normer(out)

			out = self.classifier(out)
			wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

			trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds)

			done_trans = done_trans | wds.eq(2)
			if all_done(done_trans, bsize):
				break

		return torch.cat(trans, 1)
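A minimal call sketch for this context-aware variant. The names decoder, enc_out, ctx_out, src, ctx, the pad id 0, and the mask shapes are assumptions for illustration, not taken from the source:

    # src, ctx: LongTensors of token ids (bsize, seql) with 0 assumed as the pad id;
    # enc_out, ctx_out: matching encoder and context-encoder outputs (bsize, seql, isize)
    src_pad_mask = src.eq(0).unsqueeze(1)
    context_mask = ctx.eq(0).unsqueeze(1)
    with torch.no_grad():
        hyps = decoder.greedy_decode(enc_out, ctx_out, src_pad_mask=src_pad_mask, context_mask=context_mask, max_len=256, fill_pad=True)
    # hyps: (bsize, out_len) token ids; decoding stops early once every sequence has produced eos (id 2)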
Example #2
	def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False):

		bsize = inpute.size(0)

		out = self.get_sos_emb(inpute)

		# out: input to the decoder for the first step (bsize, 1, isize)

		if self.drop is not None:
			out = self.drop(out)

		out, statefl = self.flayer(out, "init", True)

		states = {}

		if self.projector is not None:
			inpute = self.projector(inpute)

		attn = self.attn(out.unsqueeze(1), inpute, src_pad_mask).squeeze(1)

		for _tmp, net in enumerate(self.nets):
			out, _state = net(out, attn, "init", True)
			states[_tmp] = _state

		if self.out_normer is not None:
			out = self.out_normer(out)

		# out: (bsize, nwd)
		out = self.classifier(torch.cat((out, attn), -1))
		# wds: (bsize)
		wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

		trans = [wds]

		# done_trans: (bsize)

		done_trans = wds.eq(2)

		for i in range(1, max_len):

			out = self.wemb(wds)

			if self.drop is not None:
				out = self.drop(out)

			out, statefl = self.flayer(out, statefl)

			attn = self.attn(out.unsqueeze(1), inpute, src_pad_mask).squeeze(1)

			for _tmp, net in enumerate(self.nets):
				out, _state = net(out, attn, states[_tmp])
				states[_tmp] = _state

			if self.out_normer is not None:
				out = self.out_normer(out)

			out = self.classifier(torch.cat((out, attn), -1))
			wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1)

			trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

			done_trans = done_trans | wds.eq(2)
			if all_done(done_trans, bsize):
				break

		return torch.stack(trans, 1)
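SampleMax is imported from the project's utilities and, when sample=True, draws the next token from the softmax distribution instead of taking the argmax. Its actual implementation is not shown here; a functionally similar helper, offered only as a sketch and not as the project's code, could look like this:

    import torch

    def sample_max(probs, dim=-1, keepdim=False):
        # draw one index per distribution over the last dimension, assuming probs
        # already holds non-negative probabilities (e.g. a softmax output)
        assert dim == -1 or dim == probs.dim() - 1
        flat = probs.reshape(-1, probs.size(-1))
        idx = torch.multinomial(flat, 1).view(probs.size()[:-1])
        return idx.unsqueeze(-1) if keepdim else idx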
Example #3
    def greedy_decode(self,
                      inpute,
                      src_pad_mask=None,
                      max_len=512,
                      fill_pad=False,
                      sample=False):

        bsize, seql, isize = inpute[0].size()

        sqrt_isize = sqrt(isize)

        outs = []
        # one decoding-state dict per ensemble member, filled in the first step
        states = []

        for model, inputu in zip(self.nets, inpute):

            sos_emb = model.get_sos_emb(inputu)

            # out: input to the decoder for the first step (bsize, 1, isize)

            out = sos_emb * sqrt_isize
            if model.pemb is not None:
                out = out + model.pemb.get_pos(0)

            if model.drop is not None:
                out = model.drop(out)

            _states = {}

            for _tmp, net in enumerate(model.nets):
                out, _state = net(inputu, None, src_pad_mask, None, out)
                _states[_tmp] = _state

            states.append(_states)

            if model.out_normer is not None:
                out = model.out_normer(out)

            # outs: [(bsize, 1, nwd)...]

            outs.append(model.classifier(out).softmax(dim=-1))

        out = torch.stack(outs).mean(0)
        wds = SampleMax(out, dim=-1, keepdim=False) if sample else out.argmax(
            dim=-1)

        trans = [wds]

        # done_trans: (bsize, 1)

        done_trans = wds.eq(2)

        for i in range(1, max_len):

            outs = []

            for _ind, (model, inputu) in enumerate(zip(self.nets, inpute)):

                out = model.wemb(wds) * sqrt_isize
                if model.pemb is not None:
                    out = out + model.pemb.get_pos(i)

                if model.drop is not None:
                    out = model.drop(out)

                for _tmp, net in enumerate(model.nets):
                    out, _state = net(inputu, states[_ind][_tmp], src_pad_mask,
                                      None, out)
                    states[_ind][_tmp] = _state

                if model.out_normer is not None:
                    out = model.out_normer(out)

                # outs: [(bsize, 1, nwd)...]
                outs.append(model.classifier(out).softmax(dim=-1))

            out = torch.stack(outs).mean(0)
            wds = SampleMax(out, dim=-1,
                            keepdim=False) if sample else out.argmax(dim=-1)

            trans.append(
                wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

            done_trans = done_trans | wds.eq(2)
            if all_done(done_trans, bsize):
                break

        return torch.cat(trans, 1)
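In this ensemble variant, inpute is a sequence of encoder outputs, one per member model, and the per-model softmax distributions are averaged at every step before the argmax or sampling. A hedged call sketch, with ens_decoder, enc_outs and src_pad_mask as assumed names:

    # enc_outs: list of per-model encoder outputs for the same source batch,
    # each of shape (bsize, seql, isize); how they are produced depends on the models
    with torch.no_grad():
        hyps = ens_decoder.greedy_decode(enc_outs, src_pad_mask=src_pad_mask, max_len=256)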
Example #4
    def greedy_decode(self,
                      inpute,
                      src_pad_mask=None,
                      max_len=512,
                      fill_pad=False,
                      sample=False):

        bsize = inpute.size(0)

        sos_emb = self.get_sos_emb(inpute)

        sqrt_isize = sqrt(sos_emb.size(-1))

        # out: input to the decoder for the first step (bsize, 1, isize)

        out = sos_emb * sqrt_isize
        if self.pemb is not None:
            out = out + self.pemb.get_pos(0)

        if self.drop is not None:
            out = self.drop(out)

        states = {}

        for _tmp, (net,
                   inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))):
            out, _state = net(inputu, None, src_pad_mask, None, out, True)
            states[_tmp] = _state

        if self.out_normer is not None:
            out = self.out_normer(out)

        out = self.classifier(out)
        wds = SampleMax(out.softmax(-1), dim=-1,
                        keepdim=False) if sample else out.argmax(dim=-1)

        trans = [wds]

        # done_trans: (bsize, 1)

        done_trans = wds.eq(2)

        for i in range(1, max_len):

            out = self.wemb(wds) * sqrt_isize
            if self.pemb is not None:
                out = out + self.pemb.get_pos(i)

            if self.drop is not None:
                out = self.drop(out)

            for _tmp, (net, inputu) in enumerate(
                    zip(self.nets, inpute.unbind(dim=-1))):
                out, _state = net(inputu, states[_tmp], src_pad_mask, None,
                                  out, True)
                states[_tmp] = _state

            if self.out_normer is not None:
                out = self.out_normer(out)

            # out: (bsize, 1, nwd)
            out = self.classifier(out)
            wds = SampleMax(out.softmax(-1), dim=-1,
                            keepdim=False) if sample else out.argmax(dim=-1)

            trans.append(
                wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

            done_trans = done_trans | wds.eq(2)
            if all_done(done_trans, bsize):
                break

        return torch.cat(trans, 1)
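This variant expects the encoder outputs of its sub-networks stacked along a trailing dimension, which unbind(dim=-1) splits back into per-network views. A small, self-contained illustration of that layout (the tensor names and sizes are made up):

    import torch

    bsize, seql, isize = 2, 5, 8
    enc_a, enc_b, enc_c = (torch.randn(bsize, seql, isize) for _ in range(3))

    # stacking on a new trailing dimension gives (bsize, seql, isize, 3) ...
    inpute = torch.stack((enc_a, enc_b, enc_c), dim=-1)
    # ... and unbind(dim=-1) recovers the three (bsize, seql, isize) tensors
    parts = inpute.unbind(dim=-1)
    assert all(p.shape == (bsize, seql, isize) for p in parts)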
Example #5
    def greedy_decode(self,
                      inpute,
                      src_pad_mask=None,
                      max_len=512,
                      fill_pad=False,
                      sample=False):

        bsize = inpute.size(0)

        # out: decoder input for the first step (bsize, 1, isize)
        out = self.get_sos_emb(inpute)

        if self.drop is not None:
            out = self.drop(out)

        states = {}

        for _tmp, net in enumerate(self.nets):
            out, _state = net(inpute,
                              "init",
                              src_pad_mask=src_pad_mask,
                              query_unit=out)
            states[_tmp] = _state

        if self.out_normer is not None:
            out = self.out_normer(out)

        out = self.classifier(out)
        wds = SampleMax(out.softmax(-1), dim=-1,
                        keepdim=False) if sample else out.argmax(dim=-1)

        trans = [wds]

        done_trans = wds.eq(2)

        for i in range(1, max_len):

            out = self.wemb(wds)

            if self.drop is not None:
                out = self.drop(out)

            for _tmp, net in enumerate(self.nets):
                out, _state = net(inpute,
                                  states[_tmp],
                                  src_pad_mask=src_pad_mask,
                                  query_unit=out)
                states[_tmp] = _state

            if self.out_normer is not None:
                out = self.out_normer(out)

            out = self.classifier(out)
            wds = SampleMax(out.softmax(-1), dim=-1,
                            keepdim=False) if sample else out.argmax(dim=-1)

            trans.append(
                wds.masked_fill(done_trans, pad_id) if fill_pad else wds)

            done_trans = done_trans | wds.eq(2)
            if all_done(done_trans, bsize):
                break

        return torch.cat(trans, 1)
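all_done comes from the project's utilities and ends decoding once every sequence in the batch has emitted the end-of-sentence token (id 2). A functionally equivalent check, stated as an assumption about its behaviour rather than the project's implementation, would be:

    def all_done(done_trans, bsize):
        # done_trans: BoolTensor with one entry per sequence in the batch;
        # decoding can stop when all bsize entries are True
        return done_trans.int().sum().item() == bsize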