Example #1
    def __init__(self, input_size, hidden_size, bias=True, dropout_in=0., dropout_rec=0., init_no_retain=True, **kw):
        super(GatedTreeLSTMCell, self).__init__(**kw)
        self.input_size, self.hidden_size, self.usebias = input_size, hidden_size, bias
        self.weight_ih = torch.nn.Parameter(torch.randn(hidden_size * 5, input_size))
        self.weight_hh = torch.nn.Parameter(torch.randn(hidden_size * 5, hidden_size))
        self.bias = torch.nn.Parameter(torch.randn(hidden_size * 5)) if bias is True else None

        # dropouts etc
        self.dropout_in, self.dropout_rec, self.dropout_rec_c = None, None, None
        if dropout_in > 0.:
            self.dropout_in = q.RecDropout(p=dropout_in)
        if dropout_rec > 0.:
            self.dropout_rec = q.RecDropout(p=dropout_rec)
            self.dropout_rec_c = q.RecDropout(p=dropout_rec)
        assert(isinstance(self.dropout_in, (torch.nn.Dropout, q.RecDropout, type(None))))
        assert(isinstance(self.dropout_rec, (torch.nn.Dropout, q.RecDropout, type(None))))
        assert(isinstance(self.dropout_rec_c, (torch.nn.Dropout, q.RecDropout, type(None))))
        self.c_tm1 = None
        self.y_tm1 = None
        self.register_buffer("c_0", torch.zeros(1, self.hidden_size))
        self.register_buffer("y_0", torch.zeros(1, self.hidden_size))

        self.prev_retains = None
        self.prev_retains_ptr = None

        self._init_no_retain = init_no_retain

        self.reset_parameters()
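
A minimal standalone sketch (not the class's actual forward) of how the 5 * hidden_size rows of weight_ih / weight_hh above are typically consumed: all gate pre-activations are computed in one matmul and then split into five blocks. All names below are illustrative.

import torch

input_size, hidden_size, batch = 64, 128, 8
weight_ih = torch.randn(hidden_size * 5, input_size)
weight_hh = torch.randn(hidden_size * 5, hidden_size)
bias = torch.randn(hidden_size * 5)

x = torch.randn(batch, input_size)        # input at the current tree node
h = torch.randn(batch, hidden_size)       # previous / child hidden state
preact = x @ weight_ih.t() + h @ weight_hh.t() + bias    # (batch, 5 * hidden_size)
g1, g2, g3, g4, g5 = preact.chunk(5, dim=-1)             # one (batch, hidden_size) block per gate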
Example #2
    def __init__(self,
                 indim,
                 dim,
                 activation=q.GeLU,
                 dropout=0.):  # in MLP: n_state=3072 (4 * n_embd)
        super(PositionWiseFeedforward, self).__init__()
        self.projA = nn.Linear(indim, dim)
        self.projB = nn.Linear(dim, indim)
        self.act = activation()
        self.dropout = q.RecDropout(dropout, shareaxis=1)
        self.indim, self.dim = indim, dim
        self.reset_parameters()
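
The shareaxis=1 argument suggests a dropout mask shared along the time axis. A minimal standalone sketch of that idea follows; it is an illustration only, not q.RecDropout's actual implementation.

import torch

def timeshared_dropout(x, p=0.1, shareaxis=1, training=True):
    # One Bernoulli mask per example, broadcast along `shareaxis` (the time
    # axis here), so the same units are dropped at every time step.
    if not training or p == 0.:
        return x
    mask_shape = list(x.shape)
    mask_shape[shareaxis] = 1
    mask = torch.bernoulli(torch.full(mask_shape, 1. - p, device=x.device))
    return x * mask / (1. - p)

x = torch.randn(4, 10, 32)        # (batch, time, feature)
y = timeshared_dropout(x, p=0.2)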
Example #3
    def __init__(self,
                 dim=512,
                 kdim=None,
                 vdim=None,
                 innerdim=None,
                 maxlen=512,
                 numlayers=6,
                 numheads=8,
                 activation=nn.ReLU,
                 embedding_dropout=0.,
                 attention_dropout=0.,
                 residual_dropout=0.,
                 scale=True,
                 relpos=False,
                 posemb=None,
                 **kw):
     """
     :param dim:     see MultiHeadAttention
     :param kdim:    see MultiHeadAttention
     :param vdim:    see MultiHeadAttention
     :param maxlen:  see MultiHeadAttention
     :param numlayers:   number of TransformerEncoderBlock layers used
     :param numheads:    see MultiHeadAttention
     :param activation:  which activation function to use in positionwise feedforward layers
     :param embedding_dropout:   dropout rate on embedding. Time-shared dropout. Not applied to position embeddings
     :param attention_dropout:   see MultiHeadAttention
     :param residual_dropout:    dropout rate on outputs of attention and feedforward layers
     :param scale:   see MultiHeadAttention
     :param relpos:  see MultiHeadAttention
     :param posemb:  if specified, must be a nn.Embedding-like, embeds position indexes in the range 0 to maxlen
     :param kw:
     """
        super(TransformerEncoder, self).__init__(**kw)
        self.maxlen = maxlen
        self.posemb = posemb
        self.embdrop = q.RecDropout(p=embedding_dropout, shareaxis=1)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(dim,
                                    kdim=kdim,
                                    vdim=vdim,
                                    innerdim=innerdim,
                                    numheads=numheads,
                                    activation=activation,
                                    attention_dropout=attention_dropout,
                                    residual_dropout=residual_dropout,
                                    scale=scale,
                                    maxlen=maxlen,
                                    relpos=relpos) for _ in range(numlayers)
        ])
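
A hypothetical sketch of the embedding handling the docstring describes: time-shared dropout on the token embeddings, with position embeddings added afterwards so they are not dropped. The names and the ordering below are assumptions, not the module's actual forward.

import torch
import torch.nn as nn

maxlen, dim, batch, seqlen = 512, 512, 2, 16
posemb = nn.Embedding(maxlen, dim)               # "nn.Embedding-like", per the docstring
emb = torch.randn(batch, seqlen, dim)            # token embeddings
emb = nn.functional.dropout(emb, p=0.1)          # stand-in for the time-shared embdrop
positions = torch.arange(seqlen).unsqueeze(0)    # position indexes 0 .. seqlen-1
x = emb + posemb(positions)                      # position embeddings added after dropout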
Example #4
    def __init__(self,
                 dim=512,
                 kdim=None,
                 vdim=None,
                 innerdim=None,
                 maxlen=512,
                 numlayers=6,
                 numheads=8,
                 activation=nn.ReLU,
                 embedding_dropout=0.,
                 attention_dropout=0.,
                 residual_dropout=0.,
                 scale=True,
                 noctx=False,
                 relpos=False,
                 posemb=None,
                 **kw):
        """
        :param noctx:   if True, no context should be given to forward(); see also TransformerDecoderBlock
        """
        super(TransformerDecoder, self).__init__(**kw)
        self.maxlen = maxlen
        self.noctx = noctx
        self.posemb = posemb
        self.embdrop = q.RecDropout(p=embedding_dropout, shareaxis=1)
        self.layers = nn.ModuleList([
            TransformerDecoderBlock(dim,
                                    kdim=kdim,
                                    vdim=vdim,
                                    innerdim=innerdim,
                                    numheads=numheads,
                                    activation=activation,
                                    attention_dropout=attention_dropout,
                                    residual_dropout=residual_dropout,
                                    scale=scale,
                                    noctx=noctx,
                                    maxlen=maxlen,
                                    relpos=relpos) for _ in range(numlayers)
        ])

        self._cell_mode = False
        self._posoffset = 0
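
A hypothetical illustration of the _cell_mode / _posoffset bookkeeping: when decoding one token per call, the position index of the current token is the number of tokens produced so far. This is an assumption about how the offset is used, not the module's actual code.

import torch

posoffset = 0                                   # number of tokens already decoded
for step in range(5):
    positions = torch.tensor([[posoffset]])     # (batch=1, seqlen=1) position index for this step
    # ... a single decoder step on the newest token would use `positions`
    #     to look up its position embedding ...
    posoffset += 1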
Example #5
    def __init__(self,
                 indim=None,
                 kdim=None,
                 vdim=None,
                 bidir=True,
                 numheads=None,
                 attention_dropout=0.,
                 residual_dropout=0.,
                 scale=True,
                 maxlen=512,
                 relpos=False,
                 **kw):
        """

        :param indim:   input dimension (also output is of this dimension)
        :param kdim:    dimension to use for key (and query) vectors. if unspecified, indim is used
        :param vdim:    dimension to use for value vectors. if unspecified, indim is used
        :param bidir:   if False, applies causality mask to prevent use of information from future time steps (left-to-right mode)
        :param numheads:    number of attention heads
        :param attention_dropout:   dropout rate to apply on the attention probabilities
        :param residual_dropout:    dropout rate to apply on the output vectors. Residual dropout is shared across time
        :param scale:   if True, attention is scaled
        :param maxlen:  maximum length of sequences to support. Necessary for relative position encodings
        :param relpos:  if True, uses relative position encodings. If "full", additionally uses learned relative-position key projections and biases (TODO: test)
        :param kw:
        """
        super(MultiHeadAttention, self).__init__(**kw)

        self.numheads, self.indim = numheads, indim
        self.bidir, self.scale = bidir, scale
        vdim = indim if vdim is None else vdim
        kdim = indim if kdim is None else kdim

        self.d_k = kdim // numheads  # dim per head in key and query
        self.d_v = vdim // numheads

        self.q_proj = nn.Linear(indim, numheads * self.d_k)
        self.k_proj = nn.Linear(indim, numheads * self.d_k)
        self.v_proj = nn.Linear(indim, numheads * self.d_v)

        # self.qkv_proj = nn.Linear(indim, numheads * (self.d_k * 2 + self.d_v))

        self.relpos = relpos
        self.relpos_emb = None
        self._cache_relpos_vec = None
        self._cache_relpos_sizes = None
        self.relpos_k_proj = None
        self.maxlen = maxlen
        if relpos is True or relpos == "full":
            # print("using simple relative position")
            waves = get_sinusoid_encoding_table(maxlen, indim, start=-maxlen)
            self.relpos_emb = torch.nn.Embedding.from_pretrained(waves,
                                                                 freeze=True)
            if relpos == "full":  # TODO: test
                self.relpos_k_proj = nn.Linear(
                    indim,
                    numheads * self.d_k)  # projecting for rel position keys
                self.relpos_u = torch.nn.Parameter(
                    torch.empty(numheads, self.d_k))
                self.relpos_v = torch.nn.Parameter(
                    torch.empty(numheads, self.d_k))

        self.vw_proj = nn.Linear(vdim, indim)

        self.attn_dropout = nn.Dropout(attention_dropout)
        self.resid_dropout = q.RecDropout(residual_dropout, shareaxis=1)

        self._cell_mode = False  # True if in cell mode --> saves previous and expects seqlen 1
        self._horizon = None
        self._prev_k = None  # (batsize, seqlen, numheads, dim)
        self._prev_v = None
        self._prev_mask = None

        self.reset_parameters()
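
A minimal standalone sketch of the per-head shapes and the options described in the docstring (scaling, the bidir=False causality mask, dropout on the attention probabilities). It is standard scaled dot-product attention, not this module's actual forward; all names below are illustrative.

import math
import torch

batch, seqlen, indim, numheads = 2, 6, 512, 8
d_k = indim // numheads                          # dim per head, as in self.d_k above
q_ = torch.randn(batch, numheads, seqlen, d_k)   # projected queries, split into heads
k_ = torch.randn(batch, numheads, seqlen, d_k)
v_ = torch.randn(batch, numheads, seqlen, d_k)

scores = q_ @ k_.transpose(-2, -1) / math.sqrt(d_k)          # scale=True
causal = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))          # bidir=False: mask future steps
probs = torch.softmax(scores, dim=-1)                        # attention_dropout would act here
out = probs @ v_                                             # (batch, numheads, seqlen, d_k)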