Code example #1
File: word.py Project: nju-websoft/SkeletonKBQA
 def override(self, emb:WordEmb, selectwords=None):
     """
     :param emb:     WordEmb whose entries will override the base (and previous overrides)
     :param selectwords:   which words to override. If None, emb.D's keys are used.
     :return:
     """
     if hasattr(emb, "word_dropout") and emb.word_dropout is not None:
         print("WARNING: word dropout of base will be applied before output. "
               "Word dropout on the emb provided here will be applied before that "
               "and the combined effect may result in over-dropout.")
     if selectwords is None:
         selectwords = set(emb.D.keys())
     # ensure that emb.D maps words in self.base.D to the same id as self.base.D
     selid = len(self.other_embs) + 1
     for k, v in self.D.items():
         if v >= emb.weight.size(0):
             raise q.SumTingWongException("the override must contain all positions of base.D but doesn't have ('{}':{})".format(k, v))
         if k in emb.D and k in selectwords:
             if emb.D[k] != v:
                 raise q.SumTingWongException("the override emb must map same words to same id "
                                              "but '{}' maps to {} in emb.D and to {} in self.base.D"
                                              .format(k, emb.D[k], v))
             # update select_mask
             self.select_mask[v] = selid
     self.other_embs.append(emb)
     return self
Code example #2
    def __init__(self,
                 mode="logits",
                 weight=None,
                 reduction="mean",
                 pos_weight=None,
                 maskid=0,
                 trueid=2,
                 **kw):
        """

        :param mode:        "logits" or "probs". If "probs", pos_weight must be None
        :param weight:
        :param reduction:
        :param pos_weight:
        :param maskid:
        :param trueid:
        :param kw:
        """
        super(AutomaskedBCELoss, self).__init__(**kw)
        self.mode = mode
        if mode == "logits":
            self.loss = torch.nn.BCEWithLogitsLoss(weight=weight,
                                                   reduction="none",
                                                   pos_weight=pos_weight)
        elif mode == "probs":
            assert (pos_weight is None)
            self.loss = torch.nn.BCELoss(weight=weight, reduction="none")
        else:
            raise q.SumTingWongException("unknown mode: {}".format(mode))
        self.reduction = reduction
        self.maskid, self.trueid = maskid, trueid
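
Usage note (not part of the original example): the "probs" constraint follows from the underlying PyTorch losses, since BCEWithLogitsLoss accepts pos_weight while BCELoss does not; reduction="none" keeps per-element losses so a mask (presumably derived from maskid/trueid in forward, which is not shown here) can be applied afterwards. A minimal plain-PyTorch sketch of the two modes:

import torch

logits = torch.tensor([[0.5, -1.2, 2.0]])
target = torch.tensor([[1., 0., 1.]])

# mode="logits": raw scores, pos_weight is supported
loss_logits = torch.nn.BCEWithLogitsLoss(reduction="none",
                                         pos_weight=torch.tensor([2.0, 1.0, 1.0]))
print(loss_logits(logits, target))

# mode="probs": inputs are probabilities, BCELoss has no pos_weight argument
loss_probs = torch.nn.BCELoss(reduction="none")
print(loss_probs(torch.sigmoid(logits), target))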
Code example #3
File: word.py Project: nilesh-c/qelos
 def __init__(self, base, merge, mode="sum"):
     super(MergedWordVecBase, self).__init__(base.D)
     self.base = base
     self.merg = merge
     self.mode = mode
     if mode not in ("sum", "cat"):
         raise q.SumTingWongException(
             "{} merge mode not supported".format(mode))
Code example #4
File: word.py Project: nilesh-c/qelos
 def merge(self, wordemb, mode="sum"):
     """
     Merges this embedding with the provided embedding using the provided mode.
     The dictionary of the provided embedding must be identical to this embedding's.
     """
     if not wordemb.D == self.D:
         raise q.SumTingWongException("must have identical dictionary")
     return MergedWordEmb(self, wordemb, mode=mode)
Code example #5
File: data.py Project: nju-websoft/SkeletonKBQA
 def add(self, x, y):   # x is a string of one example
     if self._matrix is not None:
         raise q.SumTingWongException("can't add to finalized {}".format(self.__class__.__name__))
     xtokens = self.tokenizer.tokenize(x)
     ytokens = self.tokenizer.tokenize(y)
     if len(xtokens) + len(ytokens) > self.maxlen - 3:
         _truncate_seq_pair(xtokens, ytokens, self.maxlen - 3)
     tokens = [self.start_token] + xtokens + [self.sep_token] + ytokens + [self.end_token]
     self._x_maxlen = max(self._x_maxlen, len(tokens))
     self.sep_positions.append((len(xtokens) + 1, len(tokens)))
     self.mattokens.append(tokens)
Code example #6
def datacat(datasets, mode=1):
    """
    Concatenates the given pytorch datasets. If mode == 0, creates a pytorch ConcatDataset; if mode == 1, creates a MultiDatasets.
    :return:
    """
    if mode == 0:
        return torch.utils.data.dataset.ConcatDataset(datasets)
    elif mode == 1:
        return MultiDatasets(datasets)
    else:
        raise q.SumTingWongException("mode {} not recognized".format(mode))
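
Usage sketch (not part of the original): with mode == 0 only the standard ConcatDataset is needed; mode == 1 additionally requires the MultiDatasets class, which is not shown here.

import torch

ds1 = torch.utils.data.TensorDataset(torch.arange(3))
ds2 = torch.utils.data.TensorDataset(torch.arange(3, 5))
cat = datacat([ds1, ds2], mode=0)
print(len(cat))  # 5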
Code example #7
File: word.py Project: nilesh-c/qelos
 def forward(self, x):
     base_emb, base_msk = self.base(x)
     merg_emb, merg_msk = self.merg(x)
     if self.mode == "sum":
         emb = base_emb + merg_emb
         msk = base_msk  # since dictionaries are identical
     elif self.mode == "cat":
         emb = torch.cat([base_emb, merg_emb], 1)
         msk = base_msk
     else:
         raise q.SumTingWongException()
     return emb, msk
Code example #8
File: data.py Project: nju-websoft/SkeletonKBQA
 def add(self, x):   # x is a string of one example
     if self._matrix is not None:
         raise q.SumTingWongException("can't add to finalized {}".format(self.__class__.__name__))
     xtokens = self.tokenizer.tokenize(x)
     addspace = sum([1 if tok is not None else 0 for tok in [self.start_token, self.end_token]])
     if len(xtokens) > self.maxlen - addspace:
         xtokens = xtokens[:self.maxlen-addspace]
     if self.start_token is not None:
         xtokens = [self.start_token] + xtokens
     if self.end_token is not None:
         xtokens = xtokens + [self.end_token]
     self._x_maxlen = max(self._x_maxlen, len(xtokens))
     self.mattokens.append(xtokens)
Code example #9
File: word.py Project: nilesh-c/qelos
 def forward(self, x, mask=None):
     if self.mode == "cat":  # need to split up input
         basex = x[:, :self.base.vecdim]
         mergx = x[:, self.base.vecdim:]
         # TODO: not all wordlinouts have .vecdim
     elif self.mode == "sum":
         basex, mergx = x, x
     else:
         raise q.SumTingWongException()
     baseres = self.base(basex, mask=mask)
     mergres = self.merg(mergx, mask=mask)
     res = baseres + mergres
     return res
Code example #10
File: model.py Project: lukovnikov/qelos-util
 def forward(self, qry, ctx, ctx_mask=None):
     """
     :param qry:         (batsize, dim) or (batsize, zeqlen, dim)
     :param ctx:         (batsize, seqlen, dim)
     :param ctx_mask:
     :return:
     """
     if qry.dim() == 2:
         ret = torch.einsum("bd,bsd->bs", [qry, ctx])
     elif qry.dim() == 3:
         ret = torch.einsum("bzd,bsd->bzs", [qry, ctx])
     else:
         raise q.SumTingWongException(
             "qry has unsupported dimension: {}".format(qry.dim()))
     return ret
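
The two einsum paths compute a batched dot product between the query and every context position. A standalone shape check in plain PyTorch (independent of the class above):

import torch

qry2 = torch.randn(4, 8)       # (batsize, dim)
qry3 = torch.randn(4, 6, 8)    # (batsize, zeqlen, dim)
ctx = torch.randn(4, 10, 8)    # (batsize, seqlen, dim)

print(torch.einsum("bd,bsd->bs", [qry2, ctx]).shape)    # torch.Size([4, 10])
print(torch.einsum("bzd,bsd->bzs", [qry3, ctx]).shape)  # torch.Size([4, 6, 10])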
Code example #11
File: prepareflmats.py Project: nilesh-c/qelos
def run(p="../../../../datasets/webqsp/webqsp.all.butd.vnt.info",
        outp="../../../../datasets/webqsp/flmats/"):
    tt = q.ticktock("matrix builder")

    # load info file
    tt.tick("loading info")
    info = pickle.load(open(p, "rb"))
    tt.tock("info loaded")

    # separate
    entityinfo = {}
    relationinfo = {}
    for key, val in info.items():
        if category(key) == ENT:
            entityinfo[key] = val
        elif category(key) == REL:
            relationinfo[key] = val
        else:
            raise q.SumTingWongException()

    # build and save entity matrices
    edic, names, nameschars, aliases, typenames, notabletypenames, types \
        = build_entity_matrices(entityinfo)
    pickle.dump(edic, open(outp + "webqsp.entity.dic", "wb"))
    names.save(outp + "webqsp.entity.names.sm")
    nameschars.save(outp + "webqsp.entity.names.char.sm")
    # aliases.save(outp+"webqsp.entity.aliases.sm")
    typenames.save(outp + "webqsp.entity.typenames.sm")
    notabletypenames.save(outp + "webqsp.entity.notabletypes.sm")
    types.save(outp + "webqsp.entity.types.sm")

    # build and save relation matrices
    rdic, names, domains, ranges, domainids, rangeids, urlwords, urltokens \
        = build_relation_matrices(relationinfo)
    pickle.dump(rdic, open(outp + "webqsp.relation.dic", "wb"))
    basep = outp + "webqsp.relation."
    names.save(basep + "names.sm")
    domains.save(basep + "domains.sm")
    ranges.save(basep + "ranges.sm")
    domainids.save(basep + "domainids.sm")
    rangeids.save(basep + "rangeids.sm")
    urlwords.save(basep + "urlwords.sm")
    urltokens.save(basep + "urltokens.sm")

    # reload
    tt.tick("reloading")
    enamesreloaded = q.StringMatrix.load(outp + "webqsp.entity.typenames.sm")
    tt.tock("reloaded")
Code example #12
 def get_logprob_of_sampled_alphas(self):
     if self.hard is False:
         raise q.SumTingWongException(
             "Use this only for RL on hard attention (must be in hard mode)."
         )
     probs = self.prevatts_probs * self.prevatts_samples + (
         1 - self.prevatts_probs) * (1 - self.prevatts_samples)
     logprobs = torch.log(probs)
     logprobs = logprobs * self.prevatts_mask  # mask the logprobs
     average_within_timestep = True
     if average_within_timestep:
         totals = self.prevatts_mask.sum(2) + 1e-6
         logprobs = logprobs.sum(2) / totals
     else:
         logprobs = logprobs.mean(2)
     return logprobs[:, 2:]  # (batsize, seqlen)  -- decoder mask should be applied on this later
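
The probs expression above is the Bernoulli likelihood of the hard 0/1 samples (p where the sample is 1, 1 - p where it is 0). A small standalone check of that identity against torch.distributions (illustration only, not from the original code):

import torch

p = torch.tensor([0.9, 0.2, 0.6])   # attention probabilities
s = torch.tensor([1., 0., 1.])      # hard samples
logprobs = torch.log(p * s + (1 - p) * (1 - s))
assert torch.allclose(logprobs, torch.distributions.Bernoulli(probs=p).log_prob(s))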
Code example #13
File: loss.py Project: nju-websoft/SkeletonKBQA
 def __init__(self,
              weight=None,
              reduction="mean",
              ignore_index=-100,
              mode="logits",
              **kw):
     super(CELoss, self).__init__(**kw)
     self.mode = mode
     if mode in ("logprobs", "probs"):
         self.ce = torch.nn.NLLLoss(weight=weight,
                                    reduction=reduction,
                                    ignore_index=ignore_index)
     elif mode == "logits":
         self.ce = torch.nn.CrossEntropyLoss(weight=weight,
                                             reduction=reduction,
                                             ignore_index=ignore_index)
     else:
         raise q.SumTingWongException("unknown mode {}".format(mode))
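
The mode switch mirrors a standard PyTorch equivalence: CrossEntropyLoss on logits gives the same value as NLLLoss on log-softmaxed logits, which is why "logprobs"/"probs" map to NLLLoss and "logits" maps to CrossEntropyLoss. A minimal check in plain PyTorch:

import torch

logits = torch.randn(5, 7)            # (batsize, numclasses)
gold = torch.randint(0, 7, (5,))
ce = torch.nn.CrossEntropyLoss()(logits, gold)
nll = torch.nn.NLLLoss()(torch.log_softmax(logits, -1), gold)
assert torch.allclose(ce, nll)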
Code example #14
File: seq2seq.py Project: nilesh-c/qelos
def make_decoder(emb, lin, ctxdim=100, embdim=100, dim=100,
                 attmode="bilin", decsplit=False, **kw):
    """ makes decoder
    # attention cell decoder that accepts VNT !!!
    """
    ctxdim = ctxdim if not decsplit else ctxdim // 2
    coreindim = embdim + ctxdim     # if ctx_to_decinp is True else embdim

    coretocritdim = dim if not decsplit else dim // 2
    critdim = dim + embdim          # if decinp_to_att is True else dim

    if attmode == "bilin":
        attention = q.Attention().bilinear_gen(ctxdim, critdim)
    elif attmode == "fwd":
        attention = q.Attention().forward_gen(ctxdim, critdim)
    else:
        raise q.SumTingWongException()

    attcell = q.AttentionDecoderCell(attention=attention,
                                     embedder=emb,
                                     core=q.RecStack(
                                         q.GRUCell(coreindim, dim),
                                         q.GRUCell(dim, dim),
                                     ),
                                     smo=q.Stack(
                                         q.argsave.spec(mask={"mask"}),
                                         lin,
                                         q.argmap.spec(0, mask=["mask"]),
                                         q.LogSoftmax(),
                                         q.argmap.spec(0),
                                     ),
                                     ctx_to_decinp=True,
                                     ctx_to_smo=True,
                                     state_to_smo=True,
                                     decinp_to_att=True,
                                     state_split=decsplit)
    return attcell.to_decoder()
Code example #15
def gen_datasets(which="geo"):
    pprefix = "../data/"
    if which == "geo":
        pprefix = pprefix + "geoqueries/jia2016/"
        trainp = pprefix + "train.txt"
        validp = pprefix + "test.txt"
        testp = pprefix + "test.txt"
    elif which == "atis":
        pprefix += "atis/jia2016/"
        trainp = pprefix + "train.txt"
        validp = pprefix + "dev.txt"
        testp = pprefix + "test.txt"
    elif which == "jobs":
        assert(False) # jia didn't do jobs
        pprefix += "jobqueries"
        trainp = pprefix + "train.txt"
        validp = pprefix + "test.txt"
        testp = pprefix + "test.txt"
    else:
        raise q.SumTingWongException("unknown dataset")

    nlsm = q.StringMatrix(indicate_start_end=True)
    nlsm.tokenize = lambda x: x.split()
    flsm = q.StringMatrix(indicate_start_end=True if which == "jobs" else False)
    flsm.tokenize = lambda x: x.split()
    devstart, teststart, i = 0, 0, 0
    trainwords = set()
    trainwordcounts = {}
    testwords = set()
    trainwords_fl = set()
    trainwordcounts_fl = {}
    testwords_fl = set()
    with open(trainp) as tf, open(validp) as vf, open(testp) as xf:
        for line in tf:
            line_nl, line_fl = line.strip().split("\t")
            line_fl = line_fl.replace("' ", "")
            # line_nl = " ".join(line_nl.split(" ")[::-1])
            nlsm.add(line_nl)
            flsm.add(line_fl)
            trainwords |= set(line_nl.split())
            for word in set(line_nl.split()):
                if word not in trainwordcounts:
                    trainwordcounts[word] = 0
                trainwordcounts[word] += 1
            trainwords_fl |= set(line_fl.split())
            for word in set(line_fl.split()):
                if word not in trainwordcounts_fl:
                    trainwordcounts_fl[word] = 0
                trainwordcounts_fl[word] += 1
            i += 1
        devstart = i
        for line in vf:
            line_nl, line_fl = line.strip().split("\t")
            line_fl = line_fl.replace("' ", "")
            # line_nl = " ".join(line_nl.split(" ")[::-1])
            nlsm.add(line_nl)
            flsm.add(line_fl)
            i += 1
        teststart = i
        for line in xf:
            line_nl, line_fl = line.strip().split("\t")
            line_fl = line_fl.replace("' ", "")
            # line_nl = " ".join(line_nl.split(" ")[::-1])
            nlsm.add(line_nl)
            flsm.add(line_fl)
            testwords |= set(line_nl.split())
            testwords_fl |= set(line_fl.split())
            i += 1
    nlsm.finalize()
    flsm.finalize()

    # region get gate sup
    gatesups = torch.zeros(flsm.matrix.shape[0], flsm.matrix.shape[1]+1, dtype=torch.long)
    for i in range(nlsm.matrix.shape[0]):
        nl_sent = nlsm[i].split()
        fl_sent = flsm[i].split()
        inid = False
        for j, fl_sent_token in enumerate(fl_sent):
            if re.match("_\w+id", fl_sent_token):
                inid = True
            elif fl_sent_token == ")":
                inid = False
            elif fl_sent_token == "(":
                pass
            else:
                if inid:
                    if fl_sent_token in nl_sent:
                        gatesups[i, j] = 1

    # endregion

    # region print analysis
    print("{} unique words in train, {} unique words in test, {} in test but not in train"
          .format(len(trainwords), len(testwords), len(testwords - trainwords)))
    print(testwords - trainwords)
    trainwords_once = set([k for k, v in trainwordcounts.items() if v < 2])
    print("{} unique words in train that occur only once ({} of them is in test)".format(len(trainwords_once), len(trainwords_once & testwords)))
    print(trainwords_once)
    trainwords_twice = set([k for k, v in trainwordcounts.items() if v < 3])
    print("{} unique words in train that occur only twice ({} of them is in test)".format(len(trainwords_twice), len(trainwords_twice & testwords)))
    rarerep = trainwords_once | (testwords - trainwords)
    print("{} unique rare representation words".format(len(rarerep)))
    print(rarerep)

    trainwords_fl_once = set([k for k, v in trainwordcounts_fl.items() if v < 2])
    rarerep_fl = trainwords_fl_once | (testwords_fl - trainwords_fl)
    print("{} unique rare rep words in logical forms".format(len(rarerep_fl)))
    print(rarerep_fl)
    # endregion

    # region create datasets
    nlmat = torch.tensor(nlsm.matrix).long()
    flmat = torch.tensor(flsm.matrix).long()
    gold = torch.tensor(flsm.matrix[:, 1:]).long()
    gold = torch.cat([gold, torch.zeros_like(gold[:, 0:1])], 1)
    tds = torch.utils.data.TensorDataset(nlmat[:devstart], flmat[:devstart], gold[:devstart], gatesups[:devstart][:, 1:])
    vds = torch.utils.data.TensorDataset(nlmat[devstart:teststart], flmat[devstart:teststart], gold[devstart:teststart])
    xds = torch.utils.data.TensorDataset(nlmat[teststart:], flmat[teststart:], gold[teststart:])
    # endregion
    return (tds, vds, xds), nlsm.D, flsm.D, rarerep, rarerep_fl
Code example #16
def gen_datasets(which="geo"):
    pprefix = "../data/"
    if which == "geo":
        pprefix = pprefix + "geoqueries/dong2016/"
        trainp = pprefix + "train.txt"
        validp = pprefix + "test.txt"
        testp = pprefix + "test.txt"
    elif which == "atis":
        pprefix += "atis/dong2016/"
        trainp = pprefix + "train.txt"
        validp = pprefix + "dev.txt"
        testp = pprefix + "test.txt"
    elif which == "jobs":
        pprefix += "jobqueries/dong2016/"
        trainp = pprefix + "train.txt"
        validp = pprefix + "test.txt"
        testp = pprefix + "test.txt"
    else:
        raise q.SumTingWongException("unknown dataset")

    nlsm = q.StringMatrix(indicate_start_end=True)
    nlsm.tokenize = lambda x: x.split()
    flsm = q.StringMatrix(
        indicate_start_end=True if which == "jobs" else False)
    flsm.tokenize = lambda x: x.split()
    devstart, teststart, i = 0, 0, 0
    with open(trainp) as tf, open(validp) as vf, open(testp) as xf:
        for line in tf:
            line_nl, line_fl = line.strip().split("\t")
            line_nl = " ".join(line_nl.split(" ")[::-1])
            nlsm.add(line_nl)
            flsm.add(line_fl)
            i += 1
        devstart = i
        for line in vf:
            line_nl, line_fl = line.strip().split("\t")
            line_nl = " ".join(line_nl.split(" ")[::-1])
            nlsm.add(line_nl)
            flsm.add(line_fl)
            i += 1
        teststart = i
        for line in xf:
            line_nl, line_fl = line.strip().split("\t")
            line_nl = " ".join(line_nl.split(" ")[::-1])
            nlsm.add(line_nl)
            flsm.add(line_fl)
            i += 1
    nlsm.finalize()
    flsm.finalize()

    nlmat = torch.tensor(nlsm.matrix).long()
    flmat = torch.tensor(flsm.matrix).long()
    gold = torch.tensor(flsm.matrix[:, 1:]).long()
    gold = torch.cat([gold, torch.zeros_like(gold[:, 0:1])], 1)
    tds = torch.utils.data.TensorDataset(nlmat[:devstart], flmat[:devstart],
                                         gold[:devstart])
    vds = torch.utils.data.TensorDataset(nlmat[devstart:teststart],
                                         flmat[devstart:teststart],
                                         gold[devstart:teststart])
    xds = torch.utils.data.TensorDataset(nlmat[teststart:], flmat[teststart:],
                                         gold[teststart:])
    return (tds, vds, xds), nlsm.D, flsm.D
Code example #17
def iscuda(x):
    if isinstance(x, torch.nn.Module):
        params = list(x.parameters())
        return params[0].is_cuda
    else:
        raise q.SumTingWongException("unsupported type")
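
A quick usage sketch (CPU assumed); note the helper only inspects the first parameter, so it assumes all parameters of the module live on the same device:

import torch

print(iscuda(torch.nn.Linear(4, 2)))   # False for a CPU-only module
# iscuda("not a module")               # raises q.SumTingWongException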
Code example #18
def train_batch_distill(batch=None,
                        model=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        batch_number=-1,
                        max_batches=0,
                        current_epoch=0,
                        max_epochs=0,
                        on_start=tuple(),
                        on_before_optim_step=tuple(),
                        on_after_optim_step=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param batch:   batch to run on
    :param model:   torch.nn.Module of the model
    :param optim:       torch optimizer
    :param losses:      list of losswrappers
    :param device:      device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:              collection of functions to call when batch is done
    :param mbase:           base model to distill from. Takes the inputs and produces output distributions for the student model to match. If goldgetter is specified, this is not used.
    :param goldgetter:      takes the gold and produces a softgold
    :return:
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_batch, **kwargs)

    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)

    batch_in = batch[:-1]
    gold = batch[-1]

    # run batch_in through teacher model to get teacher output distributions
    if goldgetter is not None:
        softgold = goldgetter(gold)
    elif mbase is not None:
        mbase.eval()
        q.batch_reset(mbase)
        with torch.no_grad():
            softgold = mbase(*batch_in)
    else:
        raise q.SumTingWongException(
            "goldgetter and mbase can not both be None")

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, (softgold, gold))
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
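
The docstring only states the goldgetter contract ("takes the gold and produces a softgold"). Below is a hypothetical goldgetter that turns gold token ids into one-hot target distributions, purely to illustrate the expected shapes (not from the original code); when a goldgetter is passed, mbase is not used:

import torch

def onehot_goldgetter(gold, numclasses=100):
    # gold: (batsize, seqlen) long ids -> softgold: (batsize, seqlen, numclasses)
    return torch.nn.functional.one_hot(gold, numclasses).float()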
Code example #19
File: word.py Project: nilesh-c/qelos
 def merge(self, x, mode="sum"):
     if not self.D == x.D:
         raise q.SumTingWongException()
     return MergedWordLinout(self, x, mode=mode)
Code example #20
def load_data(p="../../data/buboqa/data/bertified_dataset.npz",
              which="span/io",
              retrelD=False,
              retrelcounts=False,
              rettokD=False,
              datafrac=1.,
              wordlevel=False):
    """
    :param p:       where the stored matrices are
    :param which:   which data to include in output datasets
                        "span/io": O/I annotated spans,
                        "span/borders": begin and end positions of span
                        "rel+io": what relation (also gives "spanio" outputs to give info where entity is supposed to be (to ignore it))
                        "rel+borders": same, but gives "spanborders" instead
                        "all": everything
    :return:
    """
    tt = q.ticktock("dataloader")
    if wordlevel:
        tt.tick("loading original data word-level stringmatrix")
        wordmat, wordD, (word_devstart, word_teststart) = load_word_mat()
        twordmat, vwordmat, xwordmat = wordmat[:word_devstart], wordmat[
            word_devstart:word_teststart], wordmat[word_teststart:]
        tt.tock("loaded stringmatrix")
    tt.tick("loading saved np mats")
    data = np.load(p, allow_pickle=True)
    print(data.keys())
    relD = data["relD"].item()
    revrelD = {v: k for k, v in relD.items()}
    devstart = data["devstart"]
    teststart = data["teststart"]
    if wordlevel:
        assert (devstart == word_devstart)
        assert (teststart == word_teststart)
    tt.tock("mats loaded")
    tt.tick("loading BERT tokenizer")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tt.tock("done")

    if wordlevel:
        tokD = wordD
    else:
        tokD = tokenizer.vocab

    def pp(i):
        tokrow = data["tokmat"][i]
        iorow = [xe - 1 for xe in data["iomat"][i] if xe != 0]
        ioborderrow = data["ioborders"][i]
        rel_i = data["rels"][i]
        tokrow = tokenizer.convert_ids_to_tokens(
            [tok for tok in tokrow if tok != 0])
        # print(" ".join(tokrow))
        print(tabulate([range(len(tokrow)), tokrow, iorow]))
        print(ioborderrow)
        print(revrelD[rel_i])

    # tt.tick("printing some examples")
    # for k in range(10):
    #     print("\nExample {}".format(k))
    #     pp(k)
    # tt.tock("printed some examples")

    # datasets
    tt.tick("making datasets")
    if which == "span/io":
        selection = ["tokmat", "iomat"]
    elif which == "span/borders":
        selection = ["tokmat", "ioborders"]
    elif which == "rel+io":
        selection = ["tokmat", "iomat", "rels"]
    elif which == "rel+borders":
        selection = ["tokmat", "ioborders", "rels"]
    elif which == "all":
        selection = ["tokmat", "iomat", "ioborders", "rels"]
    else:
        raise q.SumTingWongException("unknown which mode: {}".format(which))

    if wordlevel:
        tokmat = wordmat
    else:
        tokmat = data["tokmat"]

    selected = [
        torch.tensor(data[sel] if sel != "tokmat" else tokmat).long()
        for sel in selection
    ]
    tselected = [sel[:devstart] for sel in selected]
    vselected = [sel[devstart:teststart] for sel in selected]
    xselected = [sel[teststart:] for sel in selected]

    if datafrac <= 1.:
        # restrict data such that least relations are unseen
        # get relation counts
        trainrels = data["rels"][:devstart]
        uniquerels, relcounts = np.unique(data["rels"][:devstart],
                                          return_counts=True)
        relcountsD = dict(zip(uniquerels, relcounts))
        relcounter = dict(zip(uniquerels, [0] * len(uniquerels)))
        totalcap = int(datafrac * len(trainrels))
        capperrel = max(relcountsD.values())

        def numberexamplesincluded(capperrel_):
            numberexamplesforcap = np.clip(relcounts, 0, capperrel_).sum()
            return numberexamplesforcap

        while capperrel > 0:  # TODO do binary search
            numexcapped = numberexamplesincluded(capperrel)
            if numexcapped <= totalcap:
                break
            capperrel -= 1

        print("rel count cap is {}".format(capperrel))

        remainids = []
        for i in range(len(trainrels)):
            if len(remainids) >= totalcap:
                break
            if relcounter[trainrels[i]] > capperrel:
                pass
            else:
                relcounter[trainrels[i]] += 1
                remainids.append(i)
        print("{}/{} examples retained".format(len(remainids), len(trainrels)))
        tselected_new = [sel[remainids] for sel in tselected]
        if datafrac == 1.:
            for a, b in zip(tselected_new, tselected):
                assert (np.all(a == b))
        tselected = tselected_new

    traindata = TensorDataset(*tselected)
    devdata = TensorDataset(*vselected)
    testdata = TensorDataset(*xselected)

    ret = (traindata, devdata, testdata)
    if retrelD:
        ret += (relD, )
    if rettokD:
        ret += (tokD, )
    if retrelcounts:
        ret += data["relcounts"]
    tt.tock("made datasets")
    return ret
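
A usage sketch, assuming the default bertified_dataset.npz exists at the path above:

traindata, devdata, testdata, relD = load_data(which="rel+borders", retrelD=True)
print(len(traindata), len(relD))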
Code example #21
File: model.py Project: lukovnikov/qelos-util
    def load_weights_from_tf_checkpoint(self,
                                        ckpt_path,
                                        make_mlm_pred=False,
                                        verbose=True):
        if verbose:
            print("Loading tensorflow BERT weights from {}".format(ckpt_path))
        import tensorflow as tf

        # region from Hugging Face BERT
        init_vars = tf.train.list_variables(ckpt_path)
        names = []
        arrays = []
        for name, shape in init_vars:
            if verbose:
                print("Loading {} with shape {}".format(name, shape))
            array = tf.train.load_variable(ckpt_path, name)
            if verbose:
                print("Numpy array shape {}".format(array.shape))
            names.append(name)
            arrays.append(array)
        # endregion

        # load values from tf ckpt variable paths to our paths
        def mapname(a):
            """
            LayerNorm
                beta -> bias
                gamma -> weight
            (.+)_embeddings -> {1}_embeddings.weight
            /embeddings -> emb
            /encoder
                layer_(\d+) -> layers,{1}
                    attention/output/LayerNorm -> ln_slf
                    attention -> slf_attn
                        self/query -> q_proj
                        self/key -> k_proj
                        self/value -> v_proj
                        output/dense -> vw_proj
                    intermediate/dense -> mlp.projA
                    output/dense -> mlp.projB
                    output/LayerNorm -> ln_ff
                    """
            if re.match(".+LayerNorm.+", a):
                a = re.sub("LayerNorm/gamma$", "LayerNorm/weight", a)
                a = re.sub("LayerNorm/beta$", "LayerNorm/bias", a)
            if re.match(".+_embeddings$", a):
                a = re.sub("(.+_embeddings)$", "\g<1>/weight", a)
            # a = re.sub("kernel$", "weight", a)
            a = re.sub("^embeddings", "emb", a)
            if re.match("^encoder", a):
                if re.match("^encoder/layer_\d+", a):
                    a = re.sub("^(encoder/layer)_(\d+)",
                               "encoder/layers/\g<2>", a)
                    a = re.sub("attention/output/LayerNorm", "ln_slf", a)
                    if re.match(".+attention.+", a):
                        a = re.sub("attention/self/query", "attention/q_proj",
                                   a)
                        a = re.sub("attention/self/key", "attention/k_proj", a)
                        a = re.sub("attention/self/value", "attention/v_proj",
                                   a)
                        a = re.sub("attention/output/dense",
                                   "attention/vw_proj", a)
                        a = re.sub("attention", "slf_attn", a)
                    a = re.sub("intermediate/dense", "mlp/projA", a)
                    a = re.sub("output/dense", "mlp/projB", a)
                    a = re.sub("output/LayerNorm", "ln_ff", a)
            return a

        for name, array in zip(names, arrays):
            if verbose:
                print("Loading {}".format(name))
            if re.match('.*(adam_v|adam_m)$', name):
                if verbose:
                    print("Skipping")
            elif name[:4] == "bert":
                name = name[5:]  # skip "bert/"
                name = mapname(name)
                name = name.split('/')
                pointer = self
                for m_name in name:
                    getname = m_name
                    if m_name == "kernel":
                        getname = "weight"
                    pointer = getattr(pointer, getname)
                if m_name == 'kernel':
                    array = np.transpose(array)
                try:
                    assert pointer.shape == array.shape
                except AssertionError as e:
                    e.args += (pointer.shape, array.shape)
                    raise
                pointer.data = torch.from_numpy(array)
            else:
                if verbose:
                    print("Skipping")

        if make_mlm_pred:
            vocsize, dim = self.emb.word_embeddings.weight.shape
            mlm_pred = BERTMLM_Head(dim, vocsize, hidden_act=self.hidden_act)
            if verbose:
                print("Loading MLM prediction model")
            out_weights = self.emb.word_embeddings.weight  # output layer weights tied to embeddings
            mlm_pred.out.weight = out_weights  # tie output weights to embeddings

            # load prefinal dense weight and bias, layernorm and out_bias from tf ckpt
            for name, array in zip(names, arrays):
                if re.match("cls/predictions/.+", name):
                    if verbose:
                        print("Loading {}".format(name))
                    if re.match('.*(adam_v|adam_m)$', name):
                        if verbose:
                            print("Skipping")
                        continue
                    array = torch.from_numpy(array)
                    if name == "cls/predictions/output_bias":
                        mlm_pred.out.bias.data = array
                    elif re.match("cls/predictions/transform/.+", name):
                        name = name[26:]
                        if name == "LayerNorm/beta":
                            mlm_pred.ln.bias.data = array
                        elif name == "LayerNorm/gamma":
                            mlm_pred.ln.weight.data = array
                        elif name == "dense/kernel":
                            mlm_pred.transform.weight.data = array.t()
                        elif name == "dense/bias":
                            mlm_pred.transform.bias.data = array
                        else:
                            raise q.SumTingWongException(
                                "unknown name: {}".format(name))
                    else:
                        raise q.SumTingWongException(
                            "unknown name: {}".format(name))
            return mlm_pred
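
To make the renaming in mapname concrete, here is a standalone trace of the relevant substitutions on one hypothetical TF variable path (the trailing "kernel" is only mapped to "weight", and the array transposed, later in the loading loop):

import re

name = "encoder/layer_3/attention/self/query/kernel"
name = re.sub(r"^(encoder/layer)_(\d+)", r"encoder/layers/\g<2>", name)
name = re.sub("attention/self/query", "attention/q_proj", name)
name = re.sub("attention", "slf_attn", name)
print(name)   # encoder/layers/3/slf_attn/q_proj/kernel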
Code example #22
def run(
        lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        tieweights=False,
        distill="glove",  # "rnnlm", "glove"
        seqlen=35,
        batsize=20,
        eval_batsize=80,
        cuda=False,
        gpu=0,
        test=False,
        repretrain=False,  # retrain base model instead of loading it
        savepath="rnnlm.base.pt",  # where to save after training
        glovepath="../../../data/glove/glove.300d"):
    tt = q.ticktock("script")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)
    tt.tick("loading data")
    train_batches, valid_batches, test_batches, D = \
        load_data(batsize=batsize, eval_batsize=eval_batsize,
                  seqlen=VariableSeqlen(minimum=5, maximum_offset=10, mu=seqlen, sigma=0))
    tt.tock("data loaded")
    print("{} batches in train".format(len(train_batches)))

    # region base training
    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for l in [loss] + validlosses + testlosses:  # put losses on right device
        l.loss.to(device)

    if os.path.exists(savepath) and repretrain is False:
        tt.tick("reloading base model")
        with open(savepath, "rb") as f:
            m = torch.load(f)
            m.to(device)
        tt.tock("reloaded base model")
    else:
        tt.tick("preparing training base")
        dims = [embdim] + ([encdim] * numlayers)

        m = RNNLayer_LM(*dims,
                        worddic=D,
                        dropout=dropout,
                        tieweights=tieweights).to(device)

        if test:
            for i, batch in enumerate(train_batches):
                y = m(batch[0])
                if i > 5:
                    break
            print(y.size())

        optim = torch.optim.SGD(m.parameters(), lr=lr)

        train_batch_f = partial(q.train_batch,
                                on_before_optim_step=[
                                    lambda: torch.nn.utils.clip_grad_norm_(
                                        m.parameters(), gradnorm)
                                ])
        lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                         mode="min",
                                                         factor=1 / 4,
                                                         patience=0,
                                                         verbose=True)
        lrp_f = lambda: lrp.step(validloss.get_epoch_error())

        train_epoch_f = partial(q.train_epoch,
                                model=m,
                                dataloader=train_batches,
                                optim=optim,
                                losses=[loss],
                                device=device,
                                _train_batch=train_batch_f)
        valid_epoch_f = partial(q.test_epoch,
                                model=m,
                                dataloader=valid_batches,
                                losses=validlosses,
                                device=device,
                                on_end=[lrp_f])

        tt.tock("prepared training base")
        tt.tick("training base model")
        q.run_training(train_epoch_f,
                       valid_epoch_f,
                       max_epochs=epochs,
                       validinter=1)
        tt.tock("trained base model")

        with open(savepath, "wb") as f:
            torch.save(m, f)

    tt.tick("testing base model")
    testresults = q.test_epoch(model=m,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested base model")
    # endregion

    # region distillation
    tt.tick("preparing training student")
    dims = [embdim] + ([encdim] * numlayers)
    ms = RNNLayer_LM(*dims, worddic=D, dropout=dropout,
                     tieweights=tieweights).to(device)

    loss = q.LossWrapper(q.DistillLoss(temperature=2.))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for l in [loss] + validlosses + testlosses:  # put losses on right device
        l.loss.to(device)

    optim = torch.optim.SGD(ms.parameters(), lr=lr)

    train_batch_f = partial(
        train_batch_distill,
        on_before_optim_step=[
            lambda: torch.nn.utils.clip_grad_norm_(ms.parameters(), gradnorm)
        ])
    lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                     mode="min",
                                                     factor=1 / 4,
                                                     patience=0,
                                                     verbose=True)
    lrp_f = lambda: lrp.step(validloss.get_epoch_error())

    if distill == "rnnlm":
        mbase = m
        goldgetter = None
    elif distill == "glove":
        mbase = None
        tt.tick("creating gold getter based on glove")
        goldgetter = GloveGoldGetter(glovepath, worddic=D)
        goldgetter.to(device)
        tt.tock("created gold getter")
    else:
        raise q.SumTingWongException("unknown distill mode {}".format(distill))

    train_epoch_f = partial(train_epoch_distill,
                            model=ms,
                            dataloader=train_batches,
                            optim=optim,
                            losses=[loss],
                            device=device,
                            _train_batch=train_batch_f,
                            mbase=mbase,
                            goldgetter=goldgetter)
    valid_epoch_f = partial(q.test_epoch,
                            model=ms,
                            dataloader=valid_batches,
                            losses=validlosses,
                            device=device,
                            on_end=[lrp_f])

    tt.tock("prepared training student")
    tt.tick("training student model")
    q.run_training(train_epoch_f,
                   valid_epoch_f,
                   max_epochs=epochs,
                   validinter=1)
    tt.tock("trained student model")

    tt.tick("testing student model")
    testresults = q.test_epoch(model=ms,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested student model")