Example #1
    def __init__(self,
                 isize,
                 snwd,
                 tnwd,
                 num_layer,
                 fhsize=None,
                 dropout=0.0,
                 attn_drop=0.0,
                 global_emb=False,
                 num_head=8,
                 xseql=cache_len_default,
                 ahsize=None,
                 norm_output=True,
                 bindDecoderEmb=False,
                 forbidden_index=None):

        super(NMT, self).__init__()

        enc_layer, dec_layer = parse_double_value_tuple(num_layer)

        self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop,
                           num_head, xseql, ahsize, norm_output)

        emb_w = self.enc.wemb.weight if global_emb else None

        self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop,
                           emb_w, num_head, xseql, ahsize, norm_output,
                           bindDecoderEmb, forbidden_index)
        #self.dec = Decoder(isize, tnwd, dec_layer, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index)# for RNMT

        if rel_pos_enabled:
            share_rel_pos_cache(self)
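
parse_double_value_tuple is not shown in any of these excerpts. A minimal sketch of the behavior the call site implies, assuming (my guess, not the library source) that it accepts either a single int or an (encoder, decoder) pair:

def parse_double_value_tuple(v):
    # hypothetical re-implementation: a bare value is shared by encoder
    # and decoder, a tuple/list is split into its two components
    if isinstance(v, (tuple, list)):
        return v[0], v[1]
    return v, v

enc_layer, dec_layer = parse_double_value_tuple(6)       # (6, 6)
enc_layer, dec_layer = parse_double_value_tuple((6, 4))  # (6, 4)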
Example #2
    def __init__(self,
                 isize,
                 snwd,
                 tnwd,
                 num_layer,
                 fhsize=None,
                 dropout=0.0,
                 attn_drop=0.0,
                 global_emb=False,
                 num_head=8,
                 xseql=512,
                 ahsize=None,
                 norm_output=True,
                 bindDecoderEmb=False,
                 forbidden_index=None):

        super(NMT, self).__init__()

        self.enc = Encoder(isize, snwd, num_layer, fhsize, dropout, attn_drop,
                           num_head, xseql, ahsize, norm_output)

        emb_w = self.enc.wemb.weight if global_emb else None

        self.dec = Decoder(isize, tnwd, num_layer, fhsize, dropout, attn_drop,
                           emb_w, num_head, xseql, ahsize, norm_output,
                           bindDecoderEmb, forbidden_index)
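
For orientation, a hedged construction example (vocabulary sizes and hyperparameters are made up; Encoder and Decoder must be importable as in the snippet):

# hypothetical usage: 6 layers for both encoder and decoder
model = NMT(isize=512, snwd=32000, tnwd=32000, num_layer=6,
            dropout=0.1, attn_drop=0.1, num_head=8)
# forward (see Example #6) expects token-id tensors where index 0 is padding,
# so the default mask is inpute.eq(0).unsqueeze(1)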
Example #3
 def __init__(self, vocab_size, isize=512, hsize=512, output_size=1):
     super(TwoTextClassifier, self).__init__()
     self.encoder = Encoder(isize=isize,
                            ahsize=hsize,
                            num_layer=6,
                            nwd=vocab_size)
     self.linear = nn.Linear(isize * 2, output_size)  # honor the output_size argument instead of a hard-coded 1
     self.linear2 = nn.Linear(isize, output_size)
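
The excerpt stops at the constructor. A purely hypothetical forward (the real method is not shown anywhere in these examples) that would at least be shape-consistent with the two heads above:

    def forward(self, input1, input2):
        # hypothetical: encode each text and keep its first-token vector,
        # assuming Encoder(tokens, mask) returns (bsize, seql, isize)
        enc1 = self.encoder(input1, input1.eq(0).unsqueeze(1))[:, 0]
        enc2 = self.encoder(input2, input2.eq(0).unsqueeze(1))[:, 0]
        # pair head over the concatenated representations: (bsize, output_size)
        return self.linear(torch.cat((enc1, enc2), dim=-1))

Under this reading, self.linear2 would serve a single-text variant of the same scoring.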
Example #4
    def __init__(self, vocab_size, isize=512, hsize=512):
        super(SquadQA, self).__init__()

        self.encoder = Encoder(isize=isize,
                               ahsize=hsize,
                               num_layer=6,
                               nwd=vocab_size)
        self.linear = nn.Linear(isize, isize)
        self.linear_endpos = nn.Linear(isize, isize)
        self.linear2 = nn.Linear(
            2 * isize, isize)  # used to obtain start-position informed cls rep
        '''TODO'''
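
Only the constructor survives here too; the comment on linear2 hints at a two-stage span predictor. A loudly hypothetical sketch of how the three layers could compose (none of this forward logic appears in the source):

    def forward(self, input):
        # assume Encoder(tokens, mask) returns (bsize, seql, isize)
        enc = self.encoder(input, input.eq(0).unsqueeze(1))
        cls = enc[:, 0]                                    # (bsize, isize)
        # start logits: projected cls scored against every token state
        start_logits = enc.bmm(self.linear(cls).unsqueeze(-1)).squeeze(-1)
        # start-position informed cls rep, per the comment on linear2
        start_rep = enc[torch.arange(enc.size(0)), start_logits.argmax(-1)]
        informed = self.linear2(torch.cat((cls, start_rep), dim=-1))
        end_logits = enc.bmm(self.linear_endpos(informed).unsqueeze(-1)).squeeze(-1)
        return start_logits, end_logits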
Example #5
	def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, nprev_context=2, num_layer_context=1):

		super(Encoder, self).__init__()

		_ahsize = isize if ahsize is None else ahsize

		_fhsize = _ahsize * 4 if fhsize is None else fhsize

		self.context_enc = EncoderBase(isize, nwd, num_layer if num_layer_context is None else num_layer_context, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output)
		self.enc = CrossEncoder(isize, nwd, num_layer, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output, nprev_context)

		_tmp_pad = torch.zeros(xseql, dtype=torch.long)
		_tmp_pad[0] = 1
		_tmp_pad = _tmp_pad.view(1, 1, xseql).repeat(1, nprev_context - 1, 1)
		self.register_buffer('pad', _tmp_pad)
		self.register_buffer('pad_mask', (1 - _tmp_pad).to(mask_tensor_type).unsqueeze(1))
		self.xseql = xseql

		self.nprev_context = nprev_context
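
A standalone reduction of the two registered buffers with small numbers, to make their contents concrete (the .to(mask_tensor_type) cast is omitted):

import torch

xseql, nprev_context = 4, 3
_tmp_pad = torch.zeros(xseql, dtype=torch.long)
_tmp_pad[0] = 1   # a one-token placeholder "sentence" standing in for missing context
_tmp_pad = _tmp_pad.view(1, 1, xseql).repeat(1, nprev_context - 1, 1)
print(_tmp_pad)                            # tensor([[[1, 0, 0, 0], [1, 0, 0, 0]]])
print(_tmp_pad.shape)                      # pad: (1, nprev_context - 1, xseql)
print((1 - _tmp_pad).unsqueeze(1).shape)   # pad_mask: (1, 1, nprev_context - 1, xseql)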
Example #6
class NMT(nn.Module):

    # isize: size of the word embedding
    # snwd: number of words in the source (encoder) vocabulary
    # tnwd: number of words in the target (decoder) vocabulary
    # num_layer: number of layers, an int shared by encoder and decoder or an (encoder, decoder) tuple
    # fhsize: number of hidden units for PositionwiseFeedForward
    # dropout: dropout probability
    # attn_drop: dropout for MultiHeadAttention
    # global_emb: share the embedding between encoder and decoder; requires the same vocabulary for source and target languages
    # num_head: number of heads in MultiHeadAttention
    # xseql: maximum length of a sequence
    # ahsize: number of hidden units for MultiHeadAttention
    # norm_output: apply layer normalization to the output
    # bindDecoderEmb: tie the decoder embedding and the decoder classifier weights
    # forbidden_index: token indices forbidden during decoding

    def __init__(self,
                 isize,
                 snwd,
                 tnwd,
                 num_layer,
                 fhsize=None,
                 dropout=0.0,
                 attn_drop=0.0,
                 global_emb=False,
                 num_head=8,
                 xseql=cache_len_default,
                 ahsize=None,
                 norm_output=True,
                 bindDecoderEmb=False,
                 forbidden_index=None):

        super(NMT, self).__init__()

        enc_layer, dec_layer = parse_double_value_tuple(num_layer)

        self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop,
                           num_head, xseql, ahsize, norm_output)

        emb_w = self.enc.wemb.weight if global_emb else None

        self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop,
                           emb_w, num_head, xseql, ahsize, norm_output,
                           bindDecoderEmb, forbidden_index)
        #self.dec = Decoder(isize, tnwd, dec_layer, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index)# for RNMT

        if rel_pos_enabled:
            share_rel_pos_cache(self)

    # inpute: source sentences, input to the encoder (bsize, seql)
    # inputo: decoded translation so far (bsize, nquery)
    # mask: user-specified mask; if None it defaults to:
    #	inpute.eq(0).unsqueeze(1)

    def forward(self, inpute, inputo, mask=None):

        _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask

        return self.dec(self.enc(inpute, _mask), inputo, _mask)

    def load_base(self, base_nmt):

        if "load_base" in dir(self.enc):
            self.enc.load_base(base_nmt.enc)
        else:
            self.enc = base_nmt.enc
        if "load_base" in dir(self.dec):
            self.dec.load_base(base_nmt.dec)
        else:
            self.dec = base_nmt.dec

    # inpute: source sentences, input to the encoder (bsize, seql)
    # beam_size: beam size for beam search
    # max_len: maximum length to generate

    def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0):

        mask = inpute.eq(0).unsqueeze(1)

        _max_len = inpute.size(1) + max(64, inpute.size(1) // 4) if max_len is None else max_len

        return self.dec.decode(self.enc(inpute, mask), mask, beam_size,
                               _max_len, length_penalty)

    def train_decode(self,
                     inpute,
                     beam_size=1,
                     max_len=None,
                     length_penalty=0.0,
                     mask=None):

        _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask

        _max_len = inpute.size(1) + max(64, inpute.size(1) // 4) if max_len is None else max_len

        return self.train_beam_decode(
            inpute, _mask, beam_size, _max_len,
            length_penalty) if beam_size > 1 else self.train_greedy_decode(
                inpute, _mask, _max_len)

    def train_greedy_decode(self, inpute, mask=None, max_len=512):

        ence = self.enc(inpute, mask)

        bsize, _ = inpute.size()

        # out: input to the decoder for the first step (bsize, 1)

        out = inpute.new_ones(bsize, 1)

        done_trans = None

        for i in range(0, max_len):

            _out = self.dec(ence, out, mask)

            _out = _out.argmax(dim=-1)

            wds = _out.narrow(1, _out.size(1) - 1, 1)

            out = torch.cat((out, wds), -1)

            # done_trans: (bsize)
            done_trans = wds.squeeze(1).eq(2) if done_trans is None else (
                done_trans + wds.squeeze(1).eq(2)).gt(0)

            if done_trans.sum().item() == bsize:
                break

        return out.narrow(1, 1, out.size(1) - 1)

    def train_beam_decode(self,
                          inpute,
                          mask=None,
                          beam_size=8,
                          max_len=512,
                          length_penalty=0.0,
                          return_all=False,
                          clip_beam=False):

        bsize, seql = inpute.size()

        real_bsize = bsize * beam_size

        ence = self.enc(inpute, mask).repeat(1, beam_size,
                                             1).view(real_bsize, seql, -1)

        mask = mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql)

        # out: input to the decoder for the first step (bsize * beam_size, 1)

        out = inpute.new_ones(real_bsize, 1)

        if length_penalty > 0.0:
            # lpv: length penalty vector for each beam (bsize * beam_size, 1)
            lpv = ence.new_ones(real_bsize, 1)
            lpv_base = 6.0**length_penalty

        done_trans = None
        scores = None
        sum_scores = None

        beam_size2 = beam_size * beam_size
        bsizeb2 = bsize * beam_size2

        for step in range(1, max_len + 1):

            _out = self.dec(ence, out, mask)

            # _out: (bsize * beam_size, nquery, vocab_size) => (bsize, beam_size, vocab_size)
            _out = _out.narrow(1,
                               _out.size(1) - 1, 1).view(bsize, beam_size, -1)

            # _scores/_wds: (bsize, beam_size, beam_size)
            _scores, _wds = _out.topk(beam_size, dim=-1)

            if done_trans is not None:
                _scores = _scores.masked_fill(
                    done_trans.unsqueeze(2).expand(bsize, beam_size,
                                                   beam_size),
                    0.0) + sum_scores.unsqueeze(2).expand(
                        bsize, beam_size, beam_size)

                if length_penalty > 0.0:
                    # ~mask instead of (1 - mask): subtraction is not supported on bool tensors
                    lpv = lpv.masked_fill(
                        ~done_trans.view(real_bsize, 1),
                        ((step + 5.0)**length_penalty) / lpv_base)

            # scores/_inds: (bsize, beam_size)
            if clip_beam and (length_penalty > 0.0):
                scores, _inds = (_scores /
                                 lpv.expand(real_bsize, beam_size)).view(
                                     bsize, beam_size2).topk(beam_size, dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)

                # sum_scores: (bsize, beam_size)
                sum_scores = _scores.view(bsizeb2).index_select(
                    0, _tinds).view(bsize, beam_size)

            else:
                scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size,
                                                                     dim=-1)
                _tinds = (_inds + torch.arange(
                    0,
                    bsizeb2,
                    beam_size2,
                    dtype=_inds.dtype,
                    device=_inds.device).unsqueeze(1).expand_as(_inds)
                          ).view(real_bsize)
                sum_scores = scores

            # wds: (bsize * beam_size, 1)
            wds = _wds.view(bsizeb2).index_select(0,
                                                  _tinds).view(real_bsize, 1)

            _inds = (
                _inds // beam_size +  # integer division recovers the parent beam
                torch.arange(0,
                             real_bsize,
                             beam_size,
                             dtype=_inds.dtype,
                             device=_inds.device).unsqueeze(1).expand_as(_inds)
            ).view(real_bsize)
            out = torch.cat((out.index_select(0, _inds), wds), -1)

            # done_trans: (bsize, beam_size)
            done_trans = wds.view(
                bsize, beam_size).eq(2) if done_trans is None else (
                    done_trans.view(real_bsize).index_select(0, _inds) +
                    wds.view(real_bsize).eq(2)).gt(0).view(bsize, beam_size)

            # check early stop for beam search
            # done_trans: (bsize, beam_size)
            # scores: (bsize, beam_size)

            _done = False
            if length_penalty > 0.0:
                lpv = lpv.index_select(0, _inds)
            elif (not return_all) and done_trans.select(
                    1, 0).sum().item() == bsize:
                _done = True

            # check beam states(done or not)

            if _done or (done_trans.sum().item() == real_bsize):
                break

        out = out.narrow(1, 1, out.size(1) - 1)

        # if the length penalty was not applied step-by-step (clip_beam disabled), apply it once to the final scores
        if (not clip_beam) and (length_penalty > 0.0):
            scores = scores / lpv.view(bsize, beam_size)

        if return_all:

            return out.view(bsize, beam_size, -1), scores
        else:

            # out: (bsize * beam_size, nquery) => (bsize, nquery)
            return out.view(bsize, beam_size, -1).select(1, 0)
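
The delicate part of train_beam_decode is the index arithmetic: each step ranks beam_size * beam_size candidates per sentence, and the flat top-k indices must be mapped back to a parent beam and then to rows of the flattened (bsize * beam_size, ...) batch. A standalone illustration:

import torch

bsize, beam_size = 2, 3
beam_size2 = beam_size * beam_size
real_bsize = bsize * beam_size

# per-sentence candidate scores: (sentence, parent beam, candidate word)
_scores = torch.randn(bsize, beam_size, beam_size)
scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size, dim=-1)

parent_beam = _inds // beam_size   # which beam each survivor extends
word_rank = _inds % beam_size      # rank of the word within that beam's top-k

# offsets turn per-sentence beam indices into row indices of the flattened
# batch, mirroring torch.arange(0, real_bsize, beam_size) in the code above
offsets = torch.arange(0, real_bsize, beam_size).unsqueeze(1)
rows = (parent_beam + offsets).view(real_bsize)   # argument to index_select(0, rows)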
Example #7
class Encoder(nn.Module):

	def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, nprev_context=2, num_layer_context=1):

		super(Encoder, self).__init__()

		_ahsize = isize if ahsize is None else ahsize

		_fhsize = _ahsize * 4 if fhsize is None else fhsize

		self.context_enc = EncoderBase(isize, nwd, num_layer if num_layer_context is None else num_layer_context, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output)
		self.enc = CrossEncoder(isize, nwd, num_layer, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output, nprev_context)

		_tmp_pad = torch.zeros(xseql, dtype=torch.long)
		_tmp_pad[0] = 1
		_tmp_pad = _tmp_pad.view(1, 1, xseql).repeat(1, nprev_context - 1, 1)
		self.register_buffer('pad', _tmp_pad)
		self.register_buffer('pad_mask', (1 - _tmp_pad).to(mask_tensor_type).unsqueeze(1))
		self.xseql = xseql

		self.nprev_context = nprev_context

	# inputs: (bsize, _nsent, seql), current sentences with indices nprev_context, ..., nsent - 1
	# inputc: (bsize, _nsentc, seql), context sentences with indices 0, 1, ..., nsent - 2
	# mask: (bsize, 1, _nsent, seql), generated with:
	#	mask = inputs.eq(0).unsqueeze(1)
	# where _nsent = nsent - self.nprev_context, _nsentc = nsent - 1
	def forward(self, inputs, inputc, mask=None, context_mask=None):

		bsize, nsentc, seql = inputc.size()
		_inputc = torch.cat((self.get_pad(seql).expand(bsize, -1, seql), inputc,), dim=1)
		_context_mask = None if context_mask is None else torch.cat((self.get_padmask(seql).expand(bsize, 1, -1, seql), context_mask,), dim=2)
		context = self.context_enc(_inputc.view(-1, seql), mask=None if _context_mask is None else _context_mask.view(-1, 1, seql)).view(bsize, nsentc + self.nprev_context - 1, seql, -1)
		isize = context.size(-1)
		contexts = []
		context_masks = []
		for i in range(self.nprev_context):
			contexts.append(context.narrow(1, i, nsentc).contiguous().view(-1, seql, isize))
			context_masks.append(None if _context_mask is None else _context_mask.narrow(2, i, nsentc).contiguous().view(-1, 1, seql))

		seql = inputs.size(-1)

		return self.enc(inputs.view(-1, seql), contexts, None if mask is None else mask.view(-1, 1, seql), context_masks), contexts, context_masks

	def load_base(self, base_encoder):

		self.enc.load_base(base_encoder)
		with torch.no_grad():
			self.context_enc.wemb.weight.copy_(base_encoder.wemb.weight)

	def get_pad(self, seql):

		return self.pad.narrow(-1, 0, seql) if seql <= self.xseql else torch.cat((self.pad, self.pad.new_zeros(1, self.nprev_context - 1, seql - self.xseql),), dim=-1)

	def get_padmask(self, seql):

		return self.pad_mask.narrow(-1, 0, seql) if seql <= self.xseql else torch.cat((self.pad_mask, self.pad_mask.new_ones(1, 1, self.nprev_context - 1, seql - self.xseql),), dim=-1)

	def update_vocab(self, indices):

		self.context_enc.update_vocab(indices)
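
What forward does to the context stream can be emulated standalone with toy shapes (in the real code the sliding narrow is applied to the encoded states rather than the raw token ids, but the window arithmetic is identical):

import torch

bsize, nsentc, seql, nprev_context = 2, 4, 8, 2
inputc = torch.randint(1, 100, (bsize, nsentc, seql))
pad = torch.zeros(1, nprev_context - 1, seql, dtype=torch.long)
pad[..., 0] = 1

# prepend nprev_context - 1 placeholder sentences, then take nprev_context
# sliding windows of nsentc sentences each, one per context slot
_inputc = torch.cat((pad.expand(bsize, -1, seql), inputc), dim=1)
print(_inputc.shape)                 # (bsize, nsentc + nprev_context - 1, seql) = (2, 5, 8)
windows = [_inputc.narrow(1, i, nsentc) for i in range(nprev_context)]
print([tuple(w.shape) for w in windows])   # nprev_context windows of (2, 4, 8)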
Example #8
                                            weights_file=opt.weights_file,
                                            model_type=opt.model_type,
                                            src_pad_idx=SRC_PAD_IDX,
                                            input_dropout_p=opt.emb_dropout)
    trg_embedding_layer = TrgEmbeddingLayer(opt.trg_vocab_size,
                                            opt.alphabet_size,
                                            opt.w_embedding_size,
                                            opt.c_embedding_size,
                                            trg_embedding,
                                            options_file=opt.options_file,
                                            weights_file=opt.weights_file,
                                            model_type='word',
                                            trg_pad_idx=TRG_PAD_IDX,
                                            device=device,
                                            input_dropout_p=opt.emb_dropout)
    share_enc = Encoder(opt.embedding_output, opt.hidden_dim, opt.enc_layers,
                        opt.enc_heads, opt.enc_pf_dim, opt.enc_dropout, device)
    norm_enc = Encoder(opt.embedding_output, opt.hidden_dim, opt.enc_layers,
                       opt.enc_heads, opt.enc_pf_dim, opt.enc_dropout, device)
    class_enc = Encoder(opt.embedding_output, opt.hidden_dim, opt.enc_layers,
                        opt.enc_heads, opt.enc_pf_dim, opt.enc_dropout, device)

    dec = Decoder(300, opt.trg_vocab_size, opt.hidden_dim * 2, opt.dec_layers,
                  opt.dec_heads, opt.dec_pf_dim, opt.dec_dropout, device)

    classification = Classification(opt.hidden_dim * 2, opt.num_class)

    SRC_PAD_IDX = input_vocab.word2id['<pad>']
    TGT_PAD_IDX = output_vocab.word2id['<pad>']

    model = MultiTask(src_embedding_layer,
                      trg_embedding_layer,