Example 1
    def build_copy_maps(self,
                        inp_vocab: Vocab,
                        str_action_re=re.compile(r"^([^_].*)$")):
        """Initialize input->action copy mappings and the derived masks.

        :param inp_vocab:       input-side vocabulary; stored on ``self`` and
                                used to size the input->action id mapping.
        :param str_action_re:   regex selecting which tokens count as string
                                actions (default: anything not starting with
                                an underscore).  The compiled-regex default is
                                shared across calls but never mutated, so this
                                is safe.
        """
        self.inp_vocab = inp_vocab
        # id->id mapping tensors registered as buffers so they move with the
        # module across devices and are stored in the state dict
        self.register_buffer(
            "_inp_to_act",
            torch.zeros(inp_vocab.number_of_ids(), dtype=torch.long))
        self.register_buffer(
            "_act_to_inp",
            torch.zeros(self.out_vocab.number_of_ids(), dtype=torch.long))

        # for COPY, initialize mapping from input node vocab (sgb.vocab) to output action vocab (qgb.vocab_actions)
        self._build_copy_maps(str_action_re=str_action_re)

        # compute action mask from input: actions that are doable using input copy actions are 1, others are 0
        actmask = torch.zeros(self.out_vocab.number_of_ids(),
                              dtype=torch.uint8)
        actmask.index_fill_(0, self._inp_to_act, 1)
        actmask[0] = 0  # presumably id 0 (padding/unmapped) must never be copyable — TODO confirm
        self.register_buffer("_inp_actmask", actmask)

        # rare actions: mask (0) rare output tokens out of generation;
        # buffer is None when the vocabulary has no rare ids
        self.rare_token_ids = self.out_vocab.rare_ids
        self.register_buffer("gen_mask", None)
        if len(self.rare_token_ids) > 0:
            gen_mask = torch.ones(self.out_vocab.number_of_ids())
            for rare_token_id in self.rare_token_ids:
                gen_mask[rare_token_id] = 0
            self.register_buffer("gen_mask", gen_mask)
    def __init__(self, dim, vocab:Vocab=None, numlayers:int=6, numheads:int=6,
                 dropout:float=0., maxpos=512, bertname="bert-base-uncased", **kw):
        """Tagger with a pretrained BERT encoder and a relative-position
        transformer decoder.

        :param dim:         model dimension (d_model); feed-forward size is dim*4.
        :param vocab:       output vocabulary; its size fixes the output layer.
        :param numlayers:   number of decoder layers.
        :param numheads:    number of attention heads.
        :param dropout:     dropout rate for the decoder, also patched into
                            every Dropout submodule of the BERT encoder.
        :param maxpos:      NOTE(review): accepted but unused in this __init__.
        :param bertname:    HuggingFace model name of the pretrained encoder.
        """
        super(TransformerTagger, self).__init__(**kw)
        self.vocab = vocab
        self.vocabsize = vocab.number_of_ids()
        self.dim = dim
        config = TransformerConfig(vocab_size=self.vocabsize, d_model=self.dim, d_ff=self.dim * 4,
                                   num_layers=numlayers, num_heads=numheads, dropout_rate=dropout)

        # decoder shares the encoder config, flagged as decoder
        decoder_config = deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = RelativePositionTransformer(decoder_config)

        self.out = torch.nn.Linear(self.dim, self.vocabsize)

        # zero out excluded tokens in the output distribution.
        # NOTE(review): self.exclude is not assigned here — presumably a class
        # attribute or set by a superclass; confirm.
        vocab_mask = torch.ones(self.vocabsize)
        for excl_token in self.exclude:
            if excl_token in self.vocab:
                vocab_mask[self.vocab[excl_token]] = 0
        self.register_buffer("vocab_mask", vocab_mask)

        self.bertname = bertname
        self.bert_model = BertModel.from_pretrained(self.bertname)
        # override BERT's internal dropout probabilities with ours
        def set_dropout(m:torch.nn.Module):
            if isinstance(m, torch.nn.Dropout):
                m.p = dropout
        self.bert_model.apply(set_dropout)

        # linear adapter to map BERT's hidden size onto decoder d_model if they differ
        self.adapter = None
        if self.bert_model.config.hidden_size != decoder_config.d_model:
            self.adapter = torch.nn.Linear(self.bert_model.config.hidden_size, decoder_config.d_model, bias=False)

        self.reset_parameters()
Example 3
    def __init__(self,
                 dim,
                 vocab: Vocab = None,
                 inpvocab: Vocab = None,
                 numlayers: int = 6,
                 mode="normal",
                 dropout: float = 0.,
                 worddropout: float = 0.,
                 **kw):
        """GRU decoder cell with an attention-fed stack and a BiGRU-style encoder.

        :param dim:         hidden dimension shared by embeddings, GRU layers
                            and output projections.
        :param vocab:       output vocabulary; 3 extra ids are reserved on top
                            of its size (presumably special tokens — TODO confirm).
        :param inpvocab:    input vocabulary; 5 extra ids reserved for the encoder.
        :param numlayers:   number of stacked GRU cells (and encoder layers).
        :param mode:        decoding mode flag, stored for use elsewhere.
        :param dropout:     dropout rate for decoder and encoder.
        :param worddropout: probability of replacing tokens with the mask token
                            (applied separately on input and output sides).
        """
        super(GRUDecoderCell, self).__init__(**kw)
        self.vocab = vocab
        self.inpvocab = inpvocab
        self.vocabsize = vocab.number_of_ids()
        self.dim = dim
        self.mode = mode

        self.dec_emb = torch.nn.Embedding(self.vocabsize + 3, self.dim)
        # first layer takes [embedding ; context] hence double width
        dims = [self.dim + self.dim] + [self.dim for _ in range(numlayers)]
        self.dec_stack = torch.nn.ModuleList(
            [torch.nn.GRUCell(dims[i], dims[i + 1]) for i in range(numlayers)])
        self.dropout = torch.nn.Dropout(dropout)
        # attention projections disabled (None => identity); kept for experimentation
        self.attn_linQ = None
        self.attn_linK = None
        self.attn_linV = None
        # self.attn_linQ = torch.nn.Linear(self.dim, self.dim)
        # self.attn_linK = torch.nn.Linear(self.dim, self.dim)
        # self.attn_linV = torch.nn.Linear(self.dim, self.dim)

        # combine [state ; summary] before the final output projection
        self.preout = torch.nn.Linear(self.dim + self.dim, self.dim)
        self.out = torch.nn.Linear(self.dim, self.vocabsize + 3)

        inpvocabsize = inpvocab.number_of_ids()
        # bidirectional-style encoder: half dim per direction
        self.encoder_model = Encoder(inpvocabsize + 5,
                                     self.dim,
                                     int(self.dim / 2),
                                     num_layers=numlayers,
                                     dropout=dropout)

        self.adapter = None
        # word dropout replaces tokens with the mask token, never the pad token
        self.inpworddropout = WordDropout(
            worddropout, self.inpvocab[self.inpvocab.masktoken],
            [self.inpvocab[self.inpvocab.padtoken]])
        self.worddropout = WordDropout(worddropout,
                                       self.vocab[self.vocab.masktoken],
                                       [self.vocab[self.vocab.padtoken]])

        self.reset_parameters()
Example 4
    def __init__(self,
                 dim,
                 vocab: Vocab = None,
                 numlayers: int = 6,
                 numheads: int = 6,
                 dropout: float = 0.,
                 maxpos=512,
                 bertname="bert-base-uncased",
                 baseline=False,
                 **kw):
        """Tagger with a BERT encoder and an absolute-position transformer decoder.

        :param dim:       model dimension; feed-forward size is dim*4.
        :param vocab:     output vocabulary; its size fixes the output layer.
        :param numlayers: number of decoder layers.
        :param numheads:  number of attention heads.
        :param dropout:   decoder dropout rate.
        :param maxpos:    maximum sequence length for absolute position embeddings.
        :param bertname:  HuggingFace model name of the pretrained encoder.
        :param baseline:  if True, decode causally with a dim-wide output head;
                          otherwise use a 2*dim-wide output head.
        """
        super(TransformerTagger, self).__init__(**kw)
        self.vocab = vocab
        self.vocabsize = vocab.number_of_ids()
        self.dim = dim
        self.baseline = baseline
        config = TransformerConfig(vocab_size=self.vocabsize,
                                   d_model=self.dim,
                                   d_ff=self.dim * 4,
                                   num_layers=numlayers,
                                   num_heads=numheads,
                                   dropout_rate=dropout,
                                   use_relative_position=False)

        # token + absolute position embeddings for the decoder input
        self.emb = torch.nn.Embedding(config.vocab_size, config.d_model)
        self.posemb = torch.nn.Embedding(maxpos, config.d_model)
        decoder_config = deepcopy(config)
        decoder_config.is_decoder = True
        # causal masking only in baseline (autoregressive) mode
        decoder_config.use_causal_mask = baseline
        self.decoder = TransformerStack(decoder_config)

        if baseline:
            self.out = torch.nn.Linear(self.dim, self.vocabsize)
        else:
            # non-baseline mode feeds a concatenation of two dim-sized vectors
            self.out = torch.nn.Linear(self.dim * 2, self.vocabsize)
        # self.out = MOS(self.dim, self.vocabsize, K=mosk)

        # all-ones mask: token exclusion currently disabled (see commented code)
        vocab_mask = torch.ones(self.vocabsize)
        # for excl_token in self.exclude:
        #     if excl_token in self.vocab:
        #         vocab_mask[self.vocab[excl_token]] = 0
        self.register_buffer("vocab_mask", vocab_mask)

        self.bertname = bertname
        self.bert_model = BertModel.from_pretrained(self.bertname)
        # def set_dropout(m:torch.nn.Module):
        #     if isinstance(m, torch.nn.Dropout):
        #         m.p = dropout
        # self.bert_model.apply(set_dropout)

        # linear adapter to map BERT's hidden size onto decoder d_model if they differ
        self.adapter = None
        if self.bert_model.config.hidden_size != decoder_config.d_model:
            self.adapter = torch.nn.Linear(self.bert_model.config.hidden_size,
                                           decoder_config.d_model,
                                           bias=False)

        self.reset_parameters()
Example 5
    def __init__(self, h_dim: int, vocab: Vocab = None, **kw):
        """Pointer-generator output head over *vocab*.

        :param h_dim: decoder hidden size feeding the projections.
        :param vocab: output vocabulary (stored as ``out_vocab``); the input
                      vocabulary is attached later and starts as ``None``.
        """
        super(_PtrGenOutput, self).__init__(**kw)
        # generation scores over the output vocabulary, plus a 2-way
        # copy-vs-generate gate
        self.gen_lin = torch.nn.Linear(h_dim, vocab.number_of_ids(), bias=True)
        self.copy_or_gen = torch.nn.Linear(h_dim, 2, bias=True)
        self.sm = torch.nn.Softmax(-1)
        self.logsm = torch.nn.LogSoftmax(-1)

        self.inp_vocab = None       # set later via build_copy_maps or similar
        self.out_vocab = vocab

        # dummy parameters used to absorb/inspect NaN gradients
        self.naningrad = torch.nn.Parameter(torch.zeros(1))
        self.naningrad2 = torch.nn.Parameter(torch.zeros(1))
Example 6
    def __init__(self,
                 h_dim: int,
                 inp_vocab: Vocab = None,
                 out_vocab: Vocab = None,
                 **kw):
        """Pointer-generator output layer with copy maps built at construction.

        :param h_dim:     decoder hidden size fed to the generation linear.
        :param inp_vocab: input-side vocabulary (copy source).
        :param out_vocab: output/action vocabulary (generation target).
        """
        super(SumPtrGenOutputOLD, self).__init__(**kw)
        # initialize modules
        self.gen_lin = torch.nn.Linear(h_dim,
                                       out_vocab.number_of_ids(),
                                       bias=True)
        self.sm = torch.nn.Softmax(-1)
        self.logsm = torch.nn.LogSoftmax(-1)

        self.inp_vocab, self.out_vocab = inp_vocab, out_vocab

        # id->id mapping tensors registered as buffers so they follow the
        # module across devices
        self.register_buffer(
            "_inp_to_act",
            torch.zeros(self.inp_vocab.number_of_ids(), dtype=torch.long))
        self.register_buffer(
            "_act_from_inp",
            torch.zeros(out_vocab.number_of_ids(), dtype=torch.long))

        # for COPY, initialize mapping from input node vocab (sgb.vocab) to output action vocab (qgb.vocab_actions)
        self.build_copy_maps()

        # compute action mask from input: actions that are doable using input copy actions are 1, others are 0
        actmask = torch.zeros(out_vocab.number_of_ids(), dtype=torch.uint8)
        actmask.index_fill_(0, self._inp_to_act, 1)
        self.register_buffer("_inp_actmask", actmask)

        # rare actions: remap every rare output id onto a single id
        # (presumably the UNK/rare id — TODO confirm rare_id == 1 semantics)
        self.rare_token_ids = out_vocab.rare_ids
        rare_id = 1
        if len(self.rare_token_ids) > 0:
            out_map = torch.arange(self.out_vocab.number_of_ids())
            for rare_token_id in self.rare_token_ids:
                out_map[rare_token_id] = rare_id
            self.register_buffer("out_map", out_map)
        else:
            self.register_buffer("out_map", None)
Example 7
    def __init__(self,
                 h_dim: int,
                 vocab: Vocab = None,
                 dropout: float = 0.,
                 **kw):
        """Plain generation output head over *vocab*, with an optional mask
        zeroing out rare tokens.

        :param h_dim:   decoder hidden size fed into the output linear.
        :param vocab:   output vocabulary; its size fixes the output layer.
        :param dropout: dropout rate applied by this head.
        """
        super(BasicGenOutput, self).__init__(**kw)
        self.gen_lin = torch.nn.Linear(h_dim, vocab.number_of_ids(), bias=True)
        self.sm = torch.nn.Softmax(-1)
        self.logsm = torch.nn.LogSoftmax(-1)
        self.dropout = torch.nn.Dropout(dropout)

        self.vocab = vocab

        # rare output tokens: build a 0/1 mask only when any exist,
        # otherwise register an empty (None) buffer
        self.rare_token_ids = vocab.rare_ids
        if not self.rare_token_ids:
            self.register_buffer("out_mask", None)
        else:
            mask = torch.ones(self.vocab.number_of_ids())
            for rare_id in self.rare_token_ids:
                mask[rare_id] = 0
            self.register_buffer("out_mask", mask)
Example 8
def load_ds(traindomains=("restaurants",),
            testdomain="housing",
            min_freq=1,
            mincoverage=1,
            top_k=np.inf,       # np.infty was removed in NumPy 2.0; np.inf is the supported spelling
            nl_mode="bert-base-uncased",
            fullsimplify=False,
            onlyabstract=False,
            pretrainsetting="all+lex",    # "all", "lex" or "all+lex"
            finetunesetting="lex",        # "lex", "all", "min"
            ):
    """
    Load Overnight pretraining/finetuning data and build the encoders.

    :param traindomains:        domains whose train/valid/lexicon examples feed pretraining
    :param testdomain:          held-out domain used for finetuning and testing
    :param min_freq:            minimum token frequency for the output vocabulary
    :param mincoverage:         minimum lexicon coverage for finetunesetting == "min"
    :param top_k:               keep only the top-k most frequent vocabulary tokens
    :param nl_mode:             HuggingFace tokenizer name for the natural-language side
    :param fullsimplify:        use "full" instead of "light" logical-form simplification
    :param onlyabstract:        abstract away lexical items in the logical forms
    :param pretrainsetting:     "all": use all examples from every domain
                                "lex": use only lexical examples
                                "all+lex": use both
    :param finetunesetting:     "lex": use lexical examples
                                "all": use all training examples
                                "min": use minimal lexicon-covering set of examples
                            ! Test is always over the same original test set.
                            ! Validation is over a fraction of training data
    :raises ValueError:         if finetunesetting is not one of "all"/"lex"/"min"
    :return: (pretrain-train, finetune-train, pretrain-valid, finetune-valid,
              finetune-test) datasets, the NL tokenizer, the output sequence
              encoder, and a 0/1 long tensor over vocab ids marking general tokens.
    """
    # structural tokens shared by all domains (used to build generaltokenmask)
    general_tokens = {
        "(", ")", "arg:~type", "arg:type", "op:and", "SW:concat", "cond:has",
        "arg:<=", "arg:<", "arg:>=", "arg:>", "arg:!=", "arg:=", "SW:superlative",
        "SW:CNT-arg:min", "SW:CNT-arg:<", "SW:CNT-arg:<=", "SW:CNT-arg:>=", "SW:CNT-arg:>",
        "SW:CNT-arg:max", "SW:CNT-arg:=", "arg:max",
    }

    def tokenize_and_add_start(t):
        # linearize the logical-form tree and prepend the start token
        tokens = tree_to_lisp_tokens(t)
        starttok = "@START@"
        tokens = [starttok] + tokens
        return tokens

    # gather source-domain examples (train/valid/lexicon; never test)
    sourceex = []
    for traindomain in traindomains:
        ds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full", simplify_blocks=True,
                                    restore_reverse=DATA_RESTORE_REVERSE, validfrac=.10)\
            .load(domain=traindomain)
        sourceex += ds[(None, None, lambda x: x in ("train", "valid", "lexicon"))].map(lambda x: (x[0], x[1], x[2], traindomain)).examples       # don't use test examples

    testds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full", simplify_blocks=True, restore_reverse=DATA_RESTORE_REVERSE)\
        .load(domain=testdomain)

    # tag target-domain examples with their domain name
    targetex = testds.map(lambda x: x + (testdomain,)).examples

    # select pretraining examples according to pretrainsetting
    pretrainmodes = pretrainsetting.split("+")      # hoisted: was computed twice
    pretrainex = []
    if "all" in pretrainmodes:
        pretrainex += [(a, tokenize_and_add_start(b), "pretrain", d) for a, b, c, d in sourceex if c == "train"]
    if "lex" in pretrainmodes:
        pretrainex += [(a, tokenize_and_add_start(b), "pretrain", d) for a, b, c, d in sourceex if c == "lexicon"]

    pretrainvalidex = [(a, tokenize_and_add_start(b), "pretrainvalid", d) for a, b, c, d in sourceex if c == "valid"]

    # select finetuning examples according to finetunesetting
    if finetunesetting == "all":
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d) for a, b, c, d in targetex if c == "train"]
    elif finetunesetting == "lex":
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d) for a, b, c, d in targetex if c == "lexicon"]
    elif finetunesetting == "min":
        finetunetrainex = get_maximum_spanning_examples([(a, b, c, d) for a, b, c, d in targetex if c == "train"],
                                      mincoverage=mincoverage,
                                      loadedex=[e for e in pretrainex if e[2] == "pretrain"])
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d) for a, b, c, d in finetunetrainex]
    else:
        # fail fast instead of a NameError on finetunetrainex further down
        raise ValueError(f"Unknown finetunesetting '{finetunesetting}', expected 'all', 'lex' or 'min'")
    finetunevalidex = [(a, tokenize_and_add_start(b), "ftvalid", d) for a, b, c, d in targetex if c == "valid"]
    finetunetestex = [(a, tokenize_and_add_start(b), "fttest", d) for a, b, c, d in targetex if c == "test"]
    print(f"Using mode \"{finetunesetting}\" for finetuning data: "
          f"\n\t{len(finetunetrainex)} training examples")

    allex = pretrainex + pretrainvalidex + finetunetrainex + finetunevalidex + finetunetestex
    ds = Dataset(allex)

    if onlyabstract:
        # learn the abstraction transform on non-test-domain examples only
        et = get_lf_abstract_transform(ds[lambda x: x[3] != testdomain].examples)
        ds = ds.map(lambda x: (x[0], et(x[1]), x[2], x[3]))

    # build the output-side vocabulary from pretraining/finetuning queries
    seqenc_vocab = Vocab(padid=0, startid=1, endid=2, unkid=UNKID)
    seqenc = SequenceEncoder(vocab=seqenc_vocab, tokenizer=lambda x: x,
                             add_start_token=False, add_end_token=True)
    for example in ds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=example[2] in ("pretrain", "fttrain"))
    seqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)

    # mark which vocab ids are domain-independent structural tokens
    generaltokenmask = torch.zeros(seqenc_vocab.number_of_ids(), dtype=torch.long)
    for token, tokenid in seqenc_vocab.D.items():
        if token in general_tokens:
            generaltokenmask[tokenid] = 1

    nl_tokenizer = AutoTokenizer.from_pretrained(nl_mode)
    def tokenize(x):
        # (NL tensor, LF tensor, split tag, raw NL, raw LF tokens, domain)
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0],
               seqenc.convert(x[1], return_what="tensor"),
               x[2],
               x[0], x[1], x[3])
        return ret
    tds, ftds, vds, fvds, xds = ds[(None, None, "pretrain", None)].map(tokenize), \
                          ds[(None, None, "fttrain", None)].map(tokenize), \
                          ds[(None, None, "pretrainvalid", None)].map(tokenize), \
                          ds[(None, None, "ftvalid", None)].map(tokenize), \
                          ds[(None, None, "fttest", None)].map(tokenize)
    return tds, ftds, vds, fvds, xds, nl_tokenizer, seqenc, generaltokenmask
Example 9
    def __init__(self,
                 dim,
                 vocab: Vocab = None,
                 inpvocab: Vocab = None,
                 numlayers: int = 6,
                 numheads: int = 6,
                 userelpos=False,
                 useabspos=True,
                 relposmode="basic",
                 relposrng=10,
                 dropout: float = 0.,
                 sidedrop=0.,
                 maxpos=512,
                 bertname="bert-base-uncased",
                 mode="normal",
                 priorweight=0.,
                 **kw):
        """Set-prediction model: transformer/BERT encoder plus a linear head
        scored with BCE.

        :param dim:         model dimension; feed-forward size is dim*4.
        :param vocab:       output vocabulary; its size fixes the output layer.
        :param inpvocab:    input vocabulary (only used when bertname == "vanilla").
        :param numlayers:   encoder layers (non-BERT path).
        :param numheads:    attention heads (non-BERT path).
        :param userelpos:   use relative position embeddings in the encoder.
        :param useabspos:   use absolute position embeddings in the encoder.
        :param relposmode:  only "basic" is supported; anything else raises.
        :param relposrng:   range of relative positions.
        :param dropout:     dropout rate (capped at 0.2/0.1 for BERT internals).
        :param sidedrop:    sideways dropout rate for the custom encoder stack.
        :param maxpos:      max absolute positions for the custom embeddings.
        :param bertname:    "vanilla"/"none*" for a fresh transformer, else a
                            HuggingFace BERT checkpoint name.
        :param mode:        "vib" adds VIB projection layers; "vibatt"
                            (after stripping spaces) enables VIB attention.
        :param priorweight: weight of the prior term, stored for the loss.
        """
        super(SetModel, self).__init__(**kw)
        self.vocab = vocab
        self.inpvocab = inpvocab
        self.vocabsize = vocab.number_of_ids()
        self.dim = dim
        self.userelpos = userelpos
        self.relposrng = relposrng
        self.useabspos = useabspos

        self.out = torch.nn.Linear(self.dim, self.vocabsize)
        self.bertname = bertname
        if self.bertname.startswith("none") or self.bertname == "vanilla":
            # train an encoder from scratch instead of loading BERT weights
            self.encrelposemb = None
            if self.userelpos is True:
                if relposmode == "basic":
                    self.encrelposemb = BasicRelPosEmb(self.dim, relposrng)
                # elif relposmode == "mod":
                #     self.relposemb = ModRelPosEmb(self.dim, relposrng, levels=4)
                else:
                    raise Exception(f"Unrecognized relposmode '{relposmode}'")
            # "noneXYZ" -> "bertXYZ": reuse the matching BERT tokenizer/vocab
            bname = "bert" + self.bertname[4:]
            if self.bertname == "vanilla":
                inpvocabsize = inpvocab.number_of_ids()
            else:
                tokenizer = AutoTokenizer.from_pretrained(bname)
                inpvocabsize = tokenizer.vocab_size
            encconfig = TransformerConfig(vocab_size=inpvocabsize,
                                          d_model=self.dim,
                                          d_ff=self.dim * 4,
                                          d_kv=int(self.dim / numheads),
                                          attention_dropout_rate=0.,
                                          num_layers=numlayers,
                                          num_heads=numheads,
                                          dropout_rate=dropout,
                                          sideways_dropout=sidedrop,
                                          vib_att=mode.replace(" ",
                                                               "") == "vibatt")
            encemb = TransformerEmbeddings(encconfig.vocab_size,
                                           encconfig.d_model,
                                           dropout=dropout,
                                           max_position_embeddings=maxpos,
                                           useabspos=useabspos)
            self.encoder_model = TransformerStack(encconfig,
                                                  encemb,
                                                  rel_emb=self.encrelposemb)
        else:
            # pretrained BERT encoder with capped dropout probabilities
            self.encoder_model = BertModel.from_pretrained(
                self.bertname,
                hidden_dropout_prob=min(dropout, 0.2),
                attention_probs_dropout_prob=min(dropout, 0.1))
        # linear adapter to map encoder hidden size onto dim if they differ
        self.adapter = None
        if self.encoder_model.config.hidden_size != self.dim:
            self.adapter = torch.nn.Linear(
                self.encoder_model.config.hidden_size, self.dim, bias=False)

        self.reset_parameters()

        # per-element BCE: reduction happens in the training loss
        self.bce = torch.nn.BCEWithLogitsLoss(reduction="none")

        self.mode = mode
        self.priorweight = priorweight

        if self.mode == "vib":
            # variational information bottleneck projections (mu, log-variance)
            self.vib_lin_mu = torch.nn.Linear(dim, dim)
            self.vib_lin_logvar = torch.nn.Linear(dim, dim)
Example 10
    def __init__(self,
                 dim,
                 vocab: Vocab = None,
                 inpvocab: Vocab = None,
                 numlayers: int = 6,
                 numheads: int = 6,
                 userelpos=False,
                 useabspos=True,
                 relposmode="basic",
                 relposrng=10,
                 mode="normal",
                 dropout: float = 0.,
                 worddropout: float = 0.,
                 maxpos=512,
                 bertname="bert-base-uncased",
                 **kw):
        """Transformer decoder cell with either a fresh transformer encoder or
        a pretrained BERT encoder.

        :param dim:         model dimension; feed-forward size is dim*4.
        :param vocab:       output vocabulary; its size fixes embeddings/output.
        :param inpvocab:    input vocabulary; replaced by the BERT tokenizer's
                            vocab when a pretrained encoder is used.
        :param numlayers:   layers for both decoder and (non-BERT) encoder.
        :param numheads:    attention heads.
        :param userelpos:   use relative position embeddings.
        :param useabspos:   use absolute position embeddings.
        :param relposmode:  only "basic" is supported; anything else raises.
        :param relposrng:   range of relative positions.
        :param mode:        decoding mode flag, stored for use elsewhere.
        :param dropout:     dropout rate (capped at 0.2/0.1 for BERT internals).
        :param worddropout: probability of masking tokens on input/output sides.
        :param maxpos:      max absolute positions.
        :param bertname:    "vanilla"/"none*" for a fresh encoder, else a
                            HuggingFace BERT checkpoint name.
        """
        super(TransformerDecoderCell, self).__init__(**kw)
        self.vocab = vocab
        self.inpvocab = inpvocab
        self.vocabsize = vocab.number_of_ids()
        self.dim = dim
        self.userelpos = userelpos
        self.relposrng = relposrng
        self.useabspos = useabspos
        self.mode = mode

        decconfig = TransformerConfig(vocab_size=self.vocabsize,
                                      d_model=self.dim,
                                      d_ff=self.dim * 4,
                                      d_kv=int(self.dim / numheads),
                                      num_layers=numlayers,
                                      num_heads=numheads,
                                      dropout_rate=dropout)

        self.dec_emb = torch.nn.Embedding(self.vocabsize, decconfig.d_model)
        # single learned "slot" embedding (index 0)
        self.slot_emb = torch.nn.Embedding(1, decconfig.d_model)

        self.relposemb = None
        if self.userelpos is True:
            if relposmode == "basic":
                self.relposemb = BasicRelPosEmb(self.dim, relposrng)
            # elif relposmode == "mod":
            #     self.relposemb = ModRelPosEmb(self.dim, relposrng, levels=4)
            else:
                raise Exception(f"Unrecognized relposmode '{relposmode}'")

        # fall back to absolute positions when relative ones are absent or requested
        self.absposemb = None
        if self.relposemb is None or self.useabspos is True:
            self.absposemb = torch.nn.Embedding(maxpos, decconfig.d_model)

        decoder_config = deepcopy(decconfig)
        decoder_config.is_decoder = True
        decoder_config.use_causal_mask = True
        self.decoder = TransformerStackDecoder(decoder_config,
                                               rel_emb=self.relposemb)

        self.out = torch.nn.Linear(self.dim, self.vocabsize)

        # all-ones mask: token exclusion currently disabled (see commented code)
        vocab_mask = torch.ones(self.vocabsize)
        # for excl_token in self.exclude:
        #     if excl_token in self.vocab:
        #         vocab_mask[self.vocab[excl_token]] = 0
        self.register_buffer("vocab_mask", vocab_mask)

        self.bertname = bertname
        self.encrelposemb = None
        if self.bertname.startswith("none") or self.bertname == "vanilla":
            # train an encoder from scratch instead of loading BERT weights
            if self.userelpos is True:
                if relposmode == "basic":
                    self.encrelposemb = BasicRelPosEmb(self.dim, relposrng)
                # elif relposmode == "mod":
                #     self.relposemb = ModRelPosEmb(self.dim, relposrng, levels=4)
                else:
                    raise Exception(f"Unrecognized relposmode '{relposmode}'")
            # "noneXYZ" -> "bertXYZ": reuse the matching BERT tokenizer/vocab
            bname = "bert" + self.bertname[4:]
            if self.bertname == "vanilla":
                inpvocabsize = inpvocab.number_of_ids()
                # word dropout masks tokens; pad is never masked
                self.inpworddropout = WordDropout(
                    worddropout, self.inpvocab[self.inpvocab.masktoken],
                    [self.inpvocab[self.inpvocab.padtoken]])
            else:
                tokenizer = AutoTokenizer.from_pretrained(bname)
                inpvocabsize = tokenizer.vocab_size
                # BERT special tokens are also protected from masking
                self.inpworddropout = WordDropout(
                    worddropout, self.inpvocab[self.inpvocab.masktoken], [
                        self.inpvocab["[CLS]"], self.inpvocab["[SEP]"],
                        self.inpvocab[self.inpvocab.padtoken]
                    ])
            encconfig = TransformerConfig(vocab_size=inpvocabsize,
                                          d_model=self.dim,
                                          d_ff=self.dim * 4,
                                          d_kv=int(self.dim / numheads),
                                          num_layers=numlayers,
                                          num_heads=numheads,
                                          dropout_rate=dropout)
            encemb = TransformerEmbeddings(encconfig.vocab_size,
                                           encconfig.d_model,
                                           dropout=dropout,
                                           max_position_embeddings=maxpos,
                                           useabspos=useabspos)
            self.encoder_model = TransformerStack(encconfig,
                                                  encemb,
                                                  rel_emb=self.encrelposemb)
        else:
            # pretrained BERT encoder with capped dropout probabilities
            self.encoder_model = BertModel.from_pretrained(
                self.bertname,
                hidden_dropout_prob=min(dropout, 0.2),
                attention_probs_dropout_prob=min(dropout, 0.1))
            tokenizer = AutoTokenizer.from_pretrained(self.bertname)
            inpvocabsize = tokenizer.vocab_size
            # rebuild self.inpvocab from the tokenizer so input ids line up
            # with the pretrained embeddings
            self.inpvocab = Vocab()
            for tok, id in tokenizer.vocab.items():
                self.inpvocab.D[tok] = id
            self.inpvocab.masktoken = "[MASK]"
            self.inpvocab.unktoken = "[UNK]"
            self.inpvocab.padtoken = "[PAD]"
            self.inpworddropout = WordDropout(
                worddropout, self.inpvocab[self.inpvocab.masktoken], [
                    self.inpvocab["[CLS]"], self.inpvocab["[SEP]"],
                    self.inpvocab[self.inpvocab.padtoken]
                ])

        # linear adapter to map encoder hidden size onto decoder d_model if they differ
        self.adapter = None
        if self.encoder_model.config.hidden_size != decoder_config.d_model:
            self.adapter = torch.nn.Linear(
                self.encoder_model.config.hidden_size,
                decoder_config.d_model,
                bias=False)

        # output-side word dropout: mask token substituted, pad protected
        self.worddropout = WordDropout(worddropout,
                                       self.vocab[self.vocab.masktoken],
                                       [self.vocab[self.vocab.padtoken]])

        self.reset_parameters()
Example 11
    def __init__(self,
                 dim,
                 vocab: Vocab = None,
                 numlayers: int = 6,
                 numheads: int = 6,
                 userelpos=False,
                 useabspos=True,
                 relposmode="basic",
                 relposrng=10,
                 dropout: float = 0.,
                 maxpos=512,
                 weightmode="vanilla",
                 **kw):
        """Encoder wrapper: a fresh transformer stack or a pretrained BERT.

        :param dim:         model dimension; feed-forward size is dim*4.
        :param vocab:       vocabulary; its size is used for the "vanilla" encoder.
        :param numlayers:   encoder layers (non-BERT path).
        :param numheads:    attention heads (non-BERT path).
        :param userelpos:   use relative position embeddings.
        :param useabspos:   use absolute position embeddings.
        :param relposmode:  only "basic" is supported; anything else raises.
        :param relposrng:   range of relative positions.
        :param dropout:     dropout rate (capped at 0.2/0.1 for BERT internals).
        :param maxpos:      max absolute positions for the custom embeddings.
        :param weightmode:  "vanilla"/"none*" for a fresh encoder, otherwise a
                            HuggingFace checkpoint name to load.
        """
        super(TransformerEncoder, self).__init__(**kw)
        self.vocab = vocab
        self.vocabsize = vocab.number_of_ids()
        self.dim = dim
        self.userelpos = userelpos
        self.relposrng = relposrng
        self.useabspos = useabspos

        self.weightmode = weightmode
        if self.weightmode.startswith("none") or self.weightmode == "vanilla":
            # train an encoder from scratch instead of loading pretrained weights
            self.encrelposemb = None
            if self.userelpos is True:
                if relposmode == "basic":
                    self.encrelposemb = BasicRelPosEmb(self.dim, relposrng)
                # elif relposmode == "mod":
                #     self.relposemb = ModRelPosEmb(self.dim, relposrng, levels=4)
                else:
                    raise Exception(f"Unrecognized relposmode '{relposmode}'")
            # "noneXYZ" -> "bertXYZ": reuse the matching BERT tokenizer/vocab
            bname = "bert" + self.weightmode[4:]
            if self.weightmode == "vanilla":
                inpvocabsize = self.vocabsize
            else:
                tokenizer = AutoTokenizer.from_pretrained(bname)
                inpvocabsize = tokenizer.vocab_size
            config = TransformerConfig(vocab_size=inpvocabsize,
                                       d_model=self.dim,
                                       d_ff=self.dim * 4,
                                       d_kv=int(self.dim / numheads),
                                       num_layers=numlayers,
                                       num_heads=numheads,
                                       dropout_rate=dropout)
            encemb = TransformerEmbeddings(config.vocab_size,
                                           config.d_model,
                                           dropout=dropout,
                                           max_position_embeddings=maxpos,
                                           useabspos=useabspos)
            self.encoder_model = TransformerStack(config,
                                                  encemb,
                                                  rel_emb=self.encrelposemb)
        else:
            # pretrained encoder with capped dropout probabilities
            self.encoder_model = BertModel.from_pretrained(
                self.weightmode,
                hidden_dropout_prob=min(dropout, 0.2),
                attention_probs_dropout_prob=min(dropout, 0.1))
        # linear adapter to map encoder hidden size onto dim if they differ
        self.adapter = None
        if self.encoder_model.config.hidden_size != self.dim:
            self.adapter = torch.nn.Linear(
                self.encoder_model.config.hidden_size, self.dim, bias=False)

        self.reset_parameters()
Example 12
    def __init__(self,
                 dim,
                 vocab: Vocab = None,
                 inpvocab: Vocab = None,
                 numlayers: int = 2,
                 numtmlayers=6,
                 mode="normal",
                 dropout: float = 0.,
                 worddropout: float = 0.,
                 numheads=6,
                 noencoder=False,
                 **kw):
        """GRU decoder cell with a transformer encoder and length-prediction heads.

        :param dim:         hidden dimension shared by embeddings and layers.
        :param vocab:       output vocabulary; 3 extra ids are reserved on top
                            of its size (presumably special tokens — TODO confirm).
        :param inpvocab:    input vocabulary for the encoder.
        :param numlayers:   number of stacked GRU cells in the decoder.
        :param numtmlayers: number of transformer layers in the encoder.
        :param mode:        "cont" skips creating the output projection
                            (continuous output presumably produced elsewhere).
        :param dropout:     dropout rate.
        :param worddropout: probability of masking tokens on input/output sides.
        :param numheads:    attention heads in the encoder.
        :param noencoder:   if True, no encoder model is constructed.
        """
        super(DecoderCell, self).__init__(**kw)
        self.vocab = vocab
        self.inpvocab = inpvocab
        self.vocabsize = vocab.number_of_ids()
        self.dim = dim
        self.mode = mode
        self.noencoder = noencoder
        self.numlayers = numlayers
        self.numtmlayers = numtmlayers

        self.dec_emb = torch.nn.Embedding(self.vocabsize + 3, self.dim)
        # first layer takes [embedding ; context] hence double width
        dims = [self.dim + self.dim] + [self.dim for _ in range(numlayers)]
        self.dec_stack = torch.nn.ModuleList(
            [torch.nn.GRUCell(dims[i], dims[i + 1]) for i in range(numlayers)])
        self.dropout = torch.nn.Dropout(dropout)
        # attention projections disabled (None => identity); kept for experimentation
        self.attn_linQ = None
        self.attn_linK = None
        self.attn_linV = None
        # self.attn_linQ = torch.nn.Linear(self.dim, self.dim)
        # self.attn_linK = torch.nn.Linear(self.dim, self.dim)
        # self.attn_linV = torch.nn.Linear(self.dim, self.dim)

        # combine [state ; summary] before the final output projection
        self.preout = torch.nn.Linear(self.dim + self.dim, self.dim)
        self.preoutnonlin = torch.nn.CELU()
        if self.mode == "cont":
            # continuous mode: no discrete output projection
            pass
        else:
            self.out = torch.nn.Linear(self.dim, self.vocabsize + 3)

        inpvocabsize = inpvocab.number_of_ids()
        if not self.noencoder:
            encconfig = TransformerConfig(vocab_size=inpvocabsize,
                                          d_model=self.dim,
                                          d_ff=self.dim * 4,
                                          d_kv=int(self.dim / numheads),
                                          num_layers=self.numtmlayers,
                                          num_heads=numheads,
                                          dropout_rate=dropout)
            encemb = TransformerEmbeddings(encconfig.vocab_size,
                                           encconfig.d_model,
                                           dropout=dropout,
                                           max_position_embeddings=1000,
                                           useabspos=True)
            self.encoder_model = TransformerStack(encconfig, encemb)
            # self.encoder_model = Encoder(inpvocabsize+5, self.dim, int(self.dim/2), num_layers=numlayers, dropout=dropout)

        self.adapter = None
        # word dropout replaces tokens with the mask token, never the pad token
        self.inpworddropout = WordDropout(
            worddropout, self.inpvocab[self.inpvocab.masktoken],
            [self.inpvocab[self.inpvocab.padtoken]])
        self.worddropout = WordDropout(worddropout,
                                       self.vocab[self.vocab.masktoken],
                                       [self.vocab[self.vocab.padtoken]])

        # length prediction head: scale and bias from pooled features
        self.lenlin = torch.nn.Linear(self.dim * 2, self.dim)
        self.lennonlin = torch.nn.CELU()
        self.lenbias = torch.nn.Linear(self.dim, 1)
        self.lenscale = torch.nn.Linear(self.dim, 1)

        self.reset_parameters()