def override(self, emb: WordEmb, selectwords=None):
    """
    :param emb: WordEmb whose entries will override the base (and previous overrides)
    :param selectwords: which words to override. If None, emb.D's keys are used.
    :return:
    """
    if hasattr(emb, "word_dropout") and emb.word_dropout is not None:
        print("WARNING: word dropout of base will be applied before output. "
              "Word dropout on the emb provided here will be applied before that "
              "and the combined effect may result in over-dropout.")
    if selectwords is None:
        selectwords = set(emb.D.keys())
    # ensure that emb.D maps words in self.base.D to the same id as self.base.D
    selid = len(self.other_embs) + 1
    for k, v in self.D.items():
        if v >= emb.weight.size(0):    # id v must be a valid row index into emb.weight
            raise q.SumTingWongException(
                "the override must contain all positions of base.D "
                "but doesn't have ('{}':{})".format(k, v))
        if k in emb.D and k in selectwords:
            if emb.D[k] != v:
                raise q.SumTingWongException(
                    "the override emb must map same words to same id "
                    "but '{}' maps to {} in emb.D and to {} in self.base.D"
                    .format(k, emb.D[k], v))
            # update select_mask
            self.select_mask[v] = selid
    self.other_embs.append(emb)
    return self

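# How the select_mask routing works (a sketch under the assumption that
# forward() picks, for every word id, the embedding indicated by select_mask):
# value 0 selects the base embedding, value i >= 1 selects other_embs[i - 1],
# since selid = len(self.other_embs) + 1 at registration time. So after
# base.override(embA).override(embB), a word covered by both overrides ends up
# with select_mask == 2 and is embedded by embB (the latest override wins).
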
def __init__(self, mode="logits", weight=None, reduction="mean", pos_weight=None, maskid=0, trueid=2, **kw): """ :param mode: "logits" or "probs". If "probs", pos_weight must be None :param weight: :param reduction: :param pos_weight: :param maskid: :param trueid: :param kw: """ super(AutomaskedBCELoss, self).__init__(**kw) self.mode = mode if mode == "logits": self.loss = torch.nn.BCEWithLogitsLoss(weight=weight, reduction="none", pos_weight=pos_weight) elif mode == "probs": assert (pos_weight is None) self.loss = torch.nn.BCELoss(weight=weight, reduction="none") else: raise q.SumTingWongException("unknown mode: {}".format(mode)) self.reduction = reduction self.maskid, self.trueid = maskid, trueid
def __init__(self, base, merge, mode="sum"):
    super(MergedWordVecBase, self).__init__(base.D)
    self.base = base
    self.merg = merge
    self.mode = mode
    if mode not in ("sum", "cat"):
        raise q.SumTingWongException(
            "{} merge mode not supported".format(mode))

def merge(self, wordemb, mode="sum"):
    """
    Merges this embedding with the provided embedding using the provided mode.
    The dictionary of the provided embedding must be identical to this embedding's.
    """
    if wordemb.D != self.D:
        raise q.SumTingWongException("must have identical dictionary")
    return MergedWordEmb(self, wordemb, mode=mode)

def add(self, x, y):    # x is a string of one example
    if self._matrix is not None:
        raise q.SumTingWongException("can't add to finalized {}".format(self.__class__.__name__))
    xtokens = self.tokenizer.tokenize(x)
    ytokens = self.tokenizer.tokenize(y)
    # reserve 3 positions for the special tokens: [start] x [sep] y [end]
    if len(xtokens) + len(ytokens) > self.maxlen - 3:
        _truncate_seq_pair(xtokens, ytokens, self.maxlen - 3)
    tokens = [self.start_token] + xtokens + [self.sep_token] + ytokens + [self.end_token]
    self._x_maxlen = max(self._x_maxlen, len(tokens))
    self.sep_positions.append((len(xtokens) + 1, len(tokens)))
    self.mattokens.append(tokens)

def datacat(datasets, mode=1):
    """
    Concatenates given pytorch datasets.
    If mode == 0, creates a pytorch ConcatDataset;
    if mode == 1, creates a MultiDatasets.
    :return:
    """
    if mode == 0:
        return torch.utils.data.dataset.ConcatDataset(datasets)
    elif mode == 1:
        return MultiDatasets(datasets)
    else:
        raise q.SumTingWongException("mode {} not recognized".format(mode))

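# Usage sketch for datacat. mode=0 stacks the datasets into one longer dataset;
# mode=1 returns MultiDatasets, which (an assumption, its code is not shown
# here) zips the datasets so each item yields the fields of all of them:
import torch
ds_a = torch.utils.data.TensorDataset(torch.randn(10, 4))
ds_b = torch.utils.data.TensorDataset(torch.randn(10, 4))
stacked = datacat([ds_a, ds_b], mode=0)    # ConcatDataset with 20 examples
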
def forward(self, x):
    base_emb, base_msk = self.base(x)
    merg_emb, merg_msk = self.merg(x)
    if self.mode == "sum":
        emb = base_emb + merg_emb
        msk = base_msk    # since dictionaries are identical
    elif self.mode == "cat":
        emb = torch.cat([base_emb, merg_emb], -1)    # concatenate along the embedding dim
        msk = base_msk
    else:
        raise q.SumTingWongException()
    return emb, msk

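# Shape sketch for the two merge modes, assuming base and merg both return
# (batsize, seqlen, dim) embeddings:
#   "sum": (B, T, D) + (B, T, D)            -> (B, T, D)
#   "cat": cat([(B, T, D), (B, T, D)], -1)  -> (B, T, 2*D)
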
def add(self, x):    # x is a string of one example
    if self._matrix is not None:
        raise q.SumTingWongException("can't add to finalized {}".format(self.__class__.__name__))
    xtokens = self.tokenizer.tokenize(x)
    addspace = sum([1 if tok is not None else 0 for tok in [self.start_token, self.end_token]])
    if len(xtokens) > self.maxlen - addspace:
        xtokens = xtokens[:self.maxlen - addspace]
    if self.start_token is not None:
        xtokens = [self.start_token] + xtokens
    if self.end_token is not None:
        xtokens = xtokens + [self.end_token]
    self._x_maxlen = max(self._x_maxlen, len(xtokens))
    self.mattokens.append(xtokens)

def forward(self, x, mask=None):
    if self.mode == "cat":    # need to split up input
        basex = x[:, :self.base.vecdim]
        mergx = x[:, self.base.vecdim:]
        # TODO: not all wordlinouts have .vecdim
    elif self.mode == "sum":
        basex, mergx = x, x
    else:
        raise q.SumTingWongException()
    baseres = self.base(basex, mask=mask)
    mergres = self.merg(mergx, mask=mask)
    res = baseres + mergres
    return res

def forward(self, qry, ctx, ctx_mask=None):
    """
    :param qry:      (batsize, dim) or (batsize, zseqlen, dim)
    :param ctx:      (batsize, seqlen, dim)
    :param ctx_mask:
    :return:
    """
    if qry.dim() == 2:
        ret = torch.einsum("bd,bsd->bs", [qry, ctx])
    elif qry.dim() == 3:
        ret = torch.einsum("bzd,bsd->bzs", [qry, ctx])
    else:
        raise q.SumTingWongException(
            "qry has unsupported dimension: {}".format(qry.dim()))
    return ret

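# Quick shape check of the einsum contractions used above (dot-product scores
# of each query against every context position):
import torch
qry2 = torch.randn(5, 16)        # (batsize, dim)
qry3 = torch.randn(5, 7, 16)     # (batsize, zseqlen, dim)
ctx = torch.randn(5, 11, 16)     # (batsize, seqlen, dim)
assert torch.einsum("bd,bsd->bs", [qry2, ctx]).shape == (5, 11)
assert torch.einsum("bzd,bsd->bzs", [qry3, ctx]).shape == (5, 7, 11)
# the 3-D case is equivalent to torch.bmm(qry3, ctx.transpose(1, 2))
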
def run(p="../../../../datasets/webqsp/webqsp.all.butd.vnt.info", outp="../../../../datasets/webqsp/flmats/"): tt = q.ticktock("matrix builder") # load info file tt.tick("loading info") info = pickle.load(open(p)) tt.tock("info loaded") # separate entityinfo = {} relationinfo = {} for key, val in info.items(): if category(key) == ENT: entityinfo[key] = val elif category(key) == REL: relationinfo[key] = val else: raise q.SumTingWongException() # build and save entity matrices edic, names, nameschars, aliases, typenames, notabletypenames, types \ = build_entity_matrices(entityinfo) pickle.dump(edic, open(outp + "webqsp.entity.dic", "w")) names.save(outp + "webqsp.entity.names.sm") nameschars.save(outp + "webqsp.entity.names.char.sm") # aliases.save(outp+"webqsp.entity.aliases.sm") typenames.save(outp + "webqsp.entity.typenames.sm") notabletypenames.save(outp + "webqsp.entity.notabletypes.sm") types.save(outp + "webqsp.entity.types.sm") # build and save relation matrices rdic, names, domains, ranges, domainids, rangeids, urlwords, urltokens \ = build_relation_matrices(relationinfo) pickle.dump(rdic, open(outp + "webqsp.relation.dic", "w")) basep = outp + "webqsp.relation." names.save(basep + "names.sm") domains.save(basep + "domains.sm") ranges.save(basep + "ranges.sm") domainids.save(basep + "domainids.sm") rangeids.save(basep + "rangeids.sm") urlwords.save(basep + "urlwords.sm") urltokens.save(basep + "urltokens.sm") # reload tt.tick("reloading") enamesreloaded = q.StringMatrix.load(outp + "webqsp.entity.typenames.sm") tt.tock("reloaded")
def get_logprob_of_sampled_alphas(self):
    if self.hard is False:
        raise q.SumTingWongException(
            "Use this only for RL on hard attention (must be in hard mode).")
    probs = self.prevatts_probs * self.prevatts_samples \
            + (1 - self.prevatts_probs) * (1 - self.prevatts_samples)
    logprobs = torch.log(probs)
    logprobs = logprobs * self.prevatts_mask    # mask the logprobs
    average_within_timestep = True
    if average_within_timestep:
        totals = self.prevatts_mask.sum(2) + 1e-6
        logprobs = logprobs.sum(2) / totals
    else:
        logprobs = logprobs.mean(2)
    return logprobs[:, 2:]    # (batsize, seqlen) -- decoder mask should be applied on this later

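# Why probs is the Bernoulli likelihood of the hard samples: for a sample s in
# {0, 1} with attention probability p, the expression p * s + (1 - p) * (1 - s)
# selects p when s == 1 and 1 - p when s == 0, so log(probs) is the per-position
# log-likelihood that REINFORCE-style gradients on hard attention need.
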
def __init__(self, weight=None, reduction="mean", ignore_index=-100,
             mode="logits", **kw):
    super(CELoss, self).__init__(**kw)
    self.mode = mode
    if mode in ("logprobs", "probs"):
        # NLLLoss expects log-probabilities; in "probs" mode the inputs
        # should be logged before (or inside) the loss computation
        self.ce = torch.nn.NLLLoss(weight=weight, reduction=reduction,
                                   ignore_index=ignore_index)
    elif mode == "logits":
        self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction,
                                            ignore_index=ignore_index)
    else:
        raise q.SumTingWongException("unknown mode {}".format(mode))

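# Usage sketch (calling the wrapped loss directly, since CELoss.forward is not
# shown here):
import torch
celoss = CELoss(mode="logits", ignore_index=0)
logits = torch.randn(4, 10)              # (batsize, numclasses)
gold = torch.tensor([3, 0, 1, 7])        # gold id 0 is ignored via ignore_index
loss = celoss.ce(logits, gold)           # same as CrossEntropyLoss here
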
def make_decoder(emb, lin, ctxdim=100, embdim=100, dim=100,
                 attmode="bilin", decsplit=False, **kw):
    """ Makes an attention cell decoder that accepts VNT. """
    ctxdim = ctxdim if not decsplit else ctxdim // 2
    coreindim = embdim + ctxdim         # if ctx_to_decinp is True else embdim
    coretocritdim = dim if not decsplit else dim // 2
    critdim = dim + embdim              # if decinp_to_att is True else dim
    if attmode == "bilin":
        attention = q.Attention().bilinear_gen(ctxdim, critdim)
    elif attmode == "fwd":
        attention = q.Attention().forward_gen(ctxdim, critdim)
    else:
        raise q.SumTingWongException()
    attcell = q.AttentionDecoderCell(attention=attention,
                                     embedder=emb,
                                     core=q.RecStack(
                                         q.GRUCell(coreindim, dim),
                                         q.GRUCell(dim, dim),
                                     ),
                                     smo=q.Stack(
                                         q.argsave.spec(mask={"mask"}),
                                         lin,
                                         q.argmap.spec(0, mask=["mask"]),
                                         q.LogSoftmax(),
                                         q.argmap.spec(0),
                                     ),
                                     ctx_to_decinp=True,
                                     ctx_to_smo=True,
                                     state_to_smo=True,
                                     decinp_to_att=True,
                                     state_split=decsplit)
    return attcell.to_decoder()

def gen_datasets(which="geo"): pprefix = "../data/" if which == "geo": pprefix = pprefix + "geoqueries/jia2016/" trainp = pprefix + "train.txt" validp = pprefix + "test.txt" testp = pprefix + "test.txt" elif which == "atis": pprefix += "atis/jia2016/" trainp = pprefix + "train.txt" validp = pprefix + "dev.txt" testp = pprefix + "test.txt" elif which == "jobs": assert(False) # jia didn't do jobs pprefix += "jobqueries" trainp = pprefix + "train.txt" validp = pprefix + "test.txt" testp = pprefix + "test.txt" else: raise q.SumTingWongException("unknown dataset") nlsm = q.StringMatrix(indicate_start_end=True) nlsm.tokenize = lambda x: x.split() flsm = q.StringMatrix(indicate_start_end=True if which == "jobs" else False) flsm.tokenize = lambda x: x.split() devstart, teststart, i = 0, 0, 0 trainwords = set() trainwordcounts = {} testwords = set() trainwords_fl = set() trainwordcounts_fl = {} testwords_fl = set() with open(trainp) as tf, open(validp) as vf, open(testp) as xf: for line in tf: line_nl, line_fl = line.strip().split("\t") line_fl = line_fl.replace("' ", "") # line_nl = " ".join(line_nl.split(" ")[::-1]) nlsm.add(line_nl) flsm.add(line_fl) trainwords |= set(line_nl.split()) for word in set(line_nl.split()): if word not in trainwordcounts: trainwordcounts[word] = 0 trainwordcounts[word] += 1 trainwords_fl |= set(line_fl.split()) for word in set(line_fl.split()): if word not in trainwordcounts_fl: trainwordcounts_fl[word] = 0 trainwordcounts_fl[word] += 1 i += 1 devstart = i for line in vf: line_nl, line_fl = line.strip().split("\t") line_fl = line_fl.replace("' ", "") # line_nl = " ".join(line_nl.split(" ")[::-1]) nlsm.add(line_nl) flsm.add(line_fl) i += 1 teststart = i for line in xf: line_nl, line_fl = line.strip().split("\t") line_fl = line_fl.replace("' ", "") # line_nl = " ".join(line_nl.split(" ")[::-1]) nlsm.add(line_nl) flsm.add(line_fl) testwords |= set(line_nl.split()) testwords_fl |= set(line_fl.split()) i += 1 nlsm.finalize() flsm.finalize() # region get gate sup gatesups = torch.zeros(flsm.matrix.shape[0], flsm.matrix.shape[1]+1, dtype=torch.long) for i in range(nlsm.matrix.shape[0]): nl_sent = nlsm[i].split() fl_sent = flsm[i].split() inid = False for j, fl_sent_token in enumerate(fl_sent): if re.match("_\w+id", fl_sent_token): inid = True elif fl_sent_token == ")": inid = False elif fl_sent_token == "(": pass else: if inid: if fl_sent_token in nl_sent: gatesups[i, j] = 1 # endregion # region print analysis print("{} unique words in train, {} unique words in test, {} in test but not in train" .format(len(trainwords), len(testwords), len(testwords - trainwords))) print(testwords - trainwords) trainwords_once = set([k for k, v in trainwordcounts.items() if v < 2]) print("{} unique words in train that occur only once ({} of them is in test)".format(len(trainwords_once), len(trainwords_once & testwords))) print(trainwords_once) trainwords_twice = set([k for k, v in trainwordcounts.items() if v < 3]) print("{} unique words in train that occur only twice ({} of them is in test)".format(len(trainwords_twice), len(trainwords_twice & testwords))) rarerep = trainwords_once | (testwords - trainwords) print("{} unique rare representation words".format(len(rarerep))) print(rarerep) trainwords_fl_once = set([k for k, v in trainwordcounts_fl.items() if v < 2]) rarerep_fl = trainwords_fl_once | (testwords_fl - trainwords_fl) print("{} unique rare rep words in logical forms".format(len(rarerep_fl))) print(rarerep_fl) # endregion # endregion create datasets nlmat = 
torch.tensor(nlsm.matrix).long() flmat = torch.tensor(flsm.matrix).long() gold = torch.tensor(flsm.matrix[:, 1:]).long() gold = torch.cat([gold, torch.zeros_like(gold[:, 0:1])], 1) tds = torch.utils.data.TensorDataset(nlmat[:devstart], flmat[:devstart], gold[:devstart], gatesups[:devstart][:, 1:]) vds = torch.utils.data.TensorDataset(nlmat[devstart:teststart], flmat[devstart:teststart], gold[devstart:teststart]) xds = torch.utils.data.TensorDataset(nlmat[teststart:], flmat[teststart:], gold[teststart:]) # endregion return (tds, vds, xds), nlsm.D, flsm.D, rarerep, rarerep_fl
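# How the gold matrix relates to flmat: gold is the output sequence shifted one
# step to the left (next-token targets for teacher forcing), padded with a zero
# column at the end, e.g. flmat row [<s>, a, b, c] -> gold row [a, b, c, 0].
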
def gen_datasets(which="geo"): pprefix = "../data/" if which == "geo": pprefix = pprefix + "geoqueries/dong2016/" trainp = pprefix + "train.txt" validp = pprefix + "test.txt" testp = pprefix + "test.txt" elif which == "atis": pprefix += "atis/dong2016/" trainp = pprefix + "train.txt" validp = pprefix + "dev.txt" testp = pprefix + "test.txt" elif which == "jobs": pprefix += "jobqueries/dong2016/" trainp = pprefix + "train.txt" validp = pprefix + "test.txt" testp = pprefix + "test.txt" else: raise q.SumTingWongException("unknown dataset") nlsm = q.StringMatrix(indicate_start_end=True) nlsm.tokenize = lambda x: x.split() flsm = q.StringMatrix( indicate_start_end=True if which == "jobs" else False) flsm.tokenize = lambda x: x.split() devstart, teststart, i = 0, 0, 0 with open(trainp) as tf, open(validp) as vf, open(testp) as xf: for line in tf: line_nl, line_fl = line.strip().split("\t") line_nl = " ".join(line_nl.split(" ")[::-1]) nlsm.add(line_nl) flsm.add(line_fl) i += 1 devstart = i for line in vf: line_nl, line_fl = line.strip().split("\t") line_nl = " ".join(line_nl.split(" ")[::-1]) nlsm.add(line_nl) flsm.add(line_fl) i += 1 teststart = i for line in xf: line_nl, line_fl = line.strip().split("\t") line_nl = " ".join(line_nl.split(" ")[::-1]) nlsm.add(line_nl) flsm.add(line_fl) i += 1 nlsm.finalize() flsm.finalize() nlmat = torch.tensor(nlsm.matrix).long() flmat = torch.tensor(flsm.matrix).long() gold = torch.tensor(flsm.matrix[:, 1:]).long() gold = torch.cat([gold, torch.zeros_like(gold[:, 0:1])], 1) tds = torch.utils.data.TensorDataset(nlmat[:devstart], flmat[:devstart], gold[:devstart]) vds = torch.utils.data.TensorDataset(nlmat[devstart:teststart], flmat[devstart:teststart], gold[devstart:teststart]) xds = torch.utils.data.TensorDataset(nlmat[teststart:], flmat[teststart:], gold[teststart:]) return (tds, vds, xds), nlsm.D, flsm.D
def iscuda(x):
    if isinstance(x, torch.nn.Module):
        params = list(x.parameters())
        return params[0].is_cuda    # assumes all parameters live on the same device
    else:
        raise q.SumTingWongException("unsupported type")

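# Usage sketch; note iscuda inspects only the first parameter, so it assumes a
# module with at least one parameter and all parameters on one device:
import torch
m = torch.nn.Linear(4, 4)
print(iscuda(m))    # False on CPU; True after m.to("cuda")
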
def train_batch_distill(batch=None, model=None, optim=None, losses=None,
                        device=torch.device("cpu"),
                        batch_number=-1, max_batches=0,
                        current_epoch=0, max_epochs=0,
                        on_start=tuple(), on_before_optim_step=tuple(),
                        on_after_optim_step=tuple(), on_end=tuple(),
                        run=False, mbase=None, goldgetter=None):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param batch:           batch to run on
    :param model:           torch.nn.Module of the model
    :param optim:           torch optimizer
    :param losses:          list of losswrappers
    :param device:          device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:          collection of functions to call when batch is done
    :param mbase:           base model to distill from. Takes inputs and produces output
                            distributions for the student model to match.
                            Not used if goldgetter is specified.
    :param goldgetter:      takes the gold and produces a softgold
    :return:
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_batch, **kwargs)
    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch,) if not q.issequence(batch) else batch
    batch = q.recmap(batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    batch_in = batch[:-1]
    gold = batch[-1]

    # run batch_in through teacher model to get teacher output distributions
    if goldgetter is not None:
        softgold = goldgetter(gold)
    elif mbase is not None:
        mbase.eval()
        q.batch_reset(mbase)
        with torch.no_grad():
            softgold = mbase(*batch_in)
    else:
        raise q.SumTingWongException("goldgetter and mbase can not both be None")

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, (softgold, gold))
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1, max_epochs, batch_number + 1, max_batches,
        q.pp_epoch_losses(*losses),
    )
    [e() for e in on_end]
    return ttmsg

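# Usage sketch of one distillation step with a frozen teacher (teacher_model,
# student_model and batch are placeholders; with goldgetter=None the teacher
# mbase is run on the same inputs as the student to produce the soft targets):
# msg = train_batch_distill(batch=batch, model=student_model, optim=optim,
#                           losses=[q.LossWrapper(q.DistillLoss(temperature=2.))],
#                           mbase=teacher_model)
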
def merge(self, x, mode="sum"): if not self.D == x.D: raise q.SumTingWongException() return MergedWordLinout(self, x, mode=mode)
def load_data(p="../../data/buboqa/data/bertified_dataset.npz", which="span/io", retrelD=False, retrelcounts=False, rettokD=False, datafrac=1., wordlevel=False): """ :param p: where the stored matrices are :param which: which data to include in output datasets "span/io": O/I annotated spans, "span/borders": begin and end positions of span "rel+io": what relation (also gives "spanio" outputs to give info where entity is supposed to be (to ignore it)) "rel+borders": same, but gives "spanborders" instead "all": everything :return: """ tt = q.ticktock("dataloader") if wordlevel: tt.tick("loading original data word-level stringmatrix") wordmat, wordD, (word_devstart, word_teststart) = load_word_mat() twordmat, vwordmat, xwordmat = wordmat[:word_devstart], wordmat[ word_devstart:word_teststart], wordmat[word_teststart:] tt.tock("loaded stringmatrix") tt.tick("loading saved np mats") data = np.load(p) print(data.keys()) relD = data["relD"].item() revrelD = {v: k for k, v in relD.items()} devstart = data["devstart"] teststart = data["teststart"] if wordlevel: assert (devstart == word_devstart) assert (teststart == word_teststart) tt.tock("mats loaded") tt.tick("loading BERT tokenizer") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") tt.tock("done") if wordlevel: tokD = wordD else: tokD = tokenizer.vocab def pp(i): tokrow = data["tokmat"][i] iorow = [xe - 1 for xe in data["iomat"][i] if xe != 0] ioborderrow = data["ioborders"][i] rel_i = data["rels"][i] tokrow = tokenizer.convert_ids_to_tokens( [tok for tok in tokrow if tok != 0]) # print(" ".join(tokrow)) print(tabulate([range(len(tokrow)), tokrow, iorow])) print(ioborderrow) print(revrelD[rel_i]) # tt.tick("printing some examples") # for k in range(10): # print("\nExample {}".format(k)) # pp(k) # tt.tock("printed some examples") # datasets tt.tick("making datasets") if which == "span/io": selection = ["tokmat", "iomat"] elif which == "span/borders": selection = ["tokmat", "ioborders"] elif which == "rel+io": selection = ["tokmat", "iomat", "rels"] elif which == "rel+borders": selection = ["tokmat", "ioborders", "rels"] elif which == "all": selection = ["tokmat", "iomat", "ioborders", "rels"] else: raise q.SumTingWongException("unknown which mode: {}".format(which)) if wordlevel: tokmat = wordmat else: tokmat = data["tokmat"] selected = [ torch.tensor(data[sel] if sel != "tokmat" else tokmat).long() for sel in selection ] tselected = [sel[:devstart] for sel in selected] vselected = [sel[devstart:teststart] for sel in selected] xselected = [sel[teststart:] for sel in selected] if datafrac <= 1.: # restrict data such that least relations are unseen # get relation counts trainrels = data["rels"][:devstart] uniquerels, relcounts = np.unique(data["rels"][:devstart], return_counts=True) relcountsD = dict(zip(uniquerels, relcounts)) relcounter = dict(zip(uniquerels, [0] * len(uniquerels))) totalcap = int(datafrac * len(trainrels)) capperrel = max(relcountsD.values()) def numberexamplesincluded(capperrel_): numberexamplesforcap = np.clip(relcounts, 0, capperrel_).sum() return numberexamplesforcap while capperrel > 0: # TODO do binary search numexcapped = numberexamplesincluded(capperrel) if numexcapped <= totalcap: break capperrel -= 1 print("rel count cap is {}".format(capperrel)) remainids = [] for i in range(len(trainrels)): if len(remainids) >= totalcap: break if relcounter[trainrels[i]] > capperrel: pass else: relcounter[trainrels[i]] += 1 remainids.append(i) print("{}/{} examples retained".format(len(remainids), len(trainrels))) 
tselected_new = [sel[remainids] for sel in tselected] if datafrac == 1.: for a, b in zip(tselected_new, tselected): assert (np.all(a == b)) tselected = tselected_new traindata = TensorDataset(*tselected) devdata = TensorDataset(*vselected) testdata = TensorDataset(*xselected) ret = (traindata, devdata, testdata) if retrelD: ret += (relD, ) if rettokD: ret += (tokD, ) if retrelcounts: ret += data["relcounts"] tt.tock("made datasets") return ret
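# Usage sketch (defaults as above; the returned tuple grows with the ret* flags):
# traindata, devdata, testdata, relD = load_data(which="rel+borders", retrelD=True)
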
def load_weights_from_tf_checkpoint(self, ckpt_path, make_mlm_pred=False, verbose=True):
    if verbose:
        print("Loading tensorflow BERT weights from {}".format(ckpt_path))
    import tensorflow as tf

    # region from Hugging Face BERT
    init_vars = tf.train.list_variables(ckpt_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        if verbose:
            print("Loading {} with shape {}".format(name, shape))
        array = tf.train.load_variable(ckpt_path, name)
        if verbose:
            print("Numpy array shape {}".format(array.shape))
        names.append(name)
        arrays.append(array)
    # endregion

    # load values from tf ckpt variable paths to our paths
    def mapname(a):
        """
        LayerNorm:  beta -> bias, gamma -> weight
        (.+)_embeddings -> {1}_embeddings.weight
        /embeddings -> emb
        /encoder:   layer_(\d+) -> layers,{1}
                    attention/output/LayerNorm -> ln_slf
                    attention -> slf_attn
                    self/query -> q_proj
                    self/key -> k_proj
                    self/value -> v_proj
                    output/dense -> vw_proj
                    intermediate/dense -> mlp.projA
                    output/dense -> mlp.projB
                    output/LayerNorm -> ln_ff
        """
        if re.match(r".+LayerNorm.+", a):
            a = re.sub(r"LayerNorm/gamma$", "LayerNorm/weight", a)
            a = re.sub(r"LayerNorm/beta$", "LayerNorm/bias", a)
        if re.match(r".+_embeddings$", a):
            a = re.sub(r"(.+_embeddings)$", r"\g<1>/weight", a)
        # a = re.sub("kernel$", "weight", a)
        a = re.sub(r"^embeddings", "emb", a)
        if re.match(r"^encoder", a):
            if re.match(r"^encoder/layer_\d+", a):
                a = re.sub(r"^(encoder/layer)_(\d+)", r"encoder/layers/\g<2>", a)
            a = re.sub(r"attention/output/LayerNorm", "ln_slf", a)
            if re.match(r".+attention.+", a):
                a = re.sub(r"attention/self/query", "attention/q_proj", a)
                a = re.sub(r"attention/self/key", "attention/k_proj", a)
                a = re.sub(r"attention/self/value", "attention/v_proj", a)
                a = re.sub(r"attention/output/dense", "attention/vw_proj", a)
                a = re.sub(r"attention", "slf_attn", a)
            a = re.sub(r"intermediate/dense", "mlp/projA", a)
            a = re.sub(r"output/dense", "mlp/projB", a)
            a = re.sub(r"output/LayerNorm", "ln_ff", a)
        return a

    for name, array in zip(names, arrays):
        if verbose:
            print("Loading {}".format(name))
        if re.match(r'.*(adam_v|adam_m)$', name):    # skip optimizer slot variables
            if verbose:
                print("Skipping")
        elif name[:4] == "bert":
            name = name[5:]    # skip "bert/"
            name = mapname(name)
            name = name.split('/')
            pointer = self
            for m_name in name:
                getname = m_name
                if m_name == "kernel":
                    getname = "weight"
                pointer = getattr(pointer, getname)
            if m_name == 'kernel':    # tf stores dense kernels transposed w.r.t. torch
                array = np.transpose(array)
            try:
                assert pointer.shape == array.shape
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            pointer.data = torch.from_numpy(array)
        else:
            if verbose:
                print("Skipping")

    if make_mlm_pred:
        vocsize, dim = self.emb.word_embeddings.weight.shape
        mlm_pred = BERTMLM_Head(dim, vocsize, hidden_act=self.hidden_act)
        if verbose:
            print("Loading MLM prediction model")
        out_weights = self.emb.word_embeddings.weight    # output layer weights tied to embeddings
        mlm_pred.out.weight = out_weights                # tie output weights to embeddings
        # load prefinal dense weight and bias, layernorm and out_bias from tf ckpt
        for name, array in zip(names, arrays):
            if re.match(r"cls/predictions/.+", name):
                if verbose:
                    print("Loading {}".format(name))
                if re.match(r'.*(adam_v|adam_m)$', name):
                    if verbose:
                        print("Skipping")
                    continue    # don't fall through to the name checks below
                array = torch.from_numpy(array)
                if name == "cls/predictions/output_bias":
                    mlm_pred.out.bias.data = array
                elif re.match(r"cls/predictions/transform/.+", name):
                    name = name[26:]    # strip "cls/predictions/transform/"
                    if name == "LayerNorm/beta":
                        mlm_pred.ln.bias.data = array
                    elif name == "LayerNorm/gamma":
                        mlm_pred.ln.weight.data = array
                    elif name == "dense/kernel":
                        mlm_pred.transform.weight.data = array.t()
                    elif name == "dense/bias":
                        mlm_pred.transform.bias.data = array
                    else:
                        raise q.SumTingWongException("unknown name: {}".format(name))
                else:
                    raise q.SumTingWongException("unknown name: {}".format(name))
        return mlm_pred

def run(lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        tieweights=False,
        distill="glove",                 # "rnnlm", "glove"
        seqlen=35,
        batsize=20,
        eval_batsize=80,
        cuda=False,
        gpu=0,
        test=False,
        repretrain=False,                # retrain base model instead of loading it
        savepath="rnnlm.base.pt",        # where to save after training
        glovepath="../../../data/glove/glove.300d"):
    tt = q.ticktock("script")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)
    tt.tick("loading data")
    train_batches, valid_batches, test_batches, D = \
        load_data(batsize=batsize, eval_batsize=eval_batsize,
                  seqlen=VariableSeqlen(minimum=5, maximum_offset=10, mu=seqlen, sigma=0))
    tt.tock("data loaded")
    print("{} batches in train".format(len(train_batches)))

    # region base training
    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]
    for l in [loss] + validlosses + testlosses:    # put losses on right device
        l.loss.to(device)

    if os.path.exists(savepath) and repretrain is False:
        tt.tick("reloading base model")
        with open(savepath, "rb") as f:
            m = torch.load(f)
        m.to(device)
        tt.tock("reloaded base model")
    else:
        tt.tick("preparing training base")
        dims = [embdim] + ([encdim] * numlayers)
        m = RNNLayer_LM(*dims, worddic=D, dropout=dropout, tieweights=tieweights).to(device)

        if test:
            for i, batch in enumerate(train_batches):
                y = m(batch[0])
                if i > 5:
                    break
            print(y.size())

        optim = torch.optim.SGD(m.parameters(), lr=lr)
        train_batch_f = partial(q.train_batch,
                                on_before_optim_step=[
                                    lambda: torch.nn.utils.clip_grad_norm_(m.parameters(), gradnorm)
                                ])
        lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="min", factor=1 / 4,
                                                         patience=0, verbose=True)
        lrp_f = lambda: lrp.step(validloss.get_epoch_error())

        train_epoch_f = partial(q.train_epoch, model=m, dataloader=train_batches,
                                optim=optim, losses=[loss], device=device,
                                _train_batch=train_batch_f)
        valid_epoch_f = partial(q.test_epoch, model=m, dataloader=valid_batches,
                                losses=validlosses, device=device, on_end=[lrp_f])
        tt.tock("prepared training base")

        tt.tick("training base model")
        q.run_training(train_epoch_f, valid_epoch_f, max_epochs=epochs, validinter=1)
        tt.tock("trained base model")
        with open(savepath, "wb") as f:
            torch.save(m, f)

    tt.tick("testing base model")
    testresults = q.test_epoch(model=m, dataloader=test_batches,
                               losses=testlosses, device=device)
    print(testresults)
    tt.tock("tested base model")
    # endregion

    # region distillation
    tt.tick("preparing training student")
    dims = [embdim] + ([encdim] * numlayers)
    ms = RNNLayer_LM(*dims, worddic=D, dropout=dropout, tieweights=tieweights).to(device)
    loss = q.LossWrapper(q.DistillLoss(temperature=2.))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]
    for l in [loss] + validlosses + testlosses:    # put losses on right device
        l.loss.to(device)

    optim = torch.optim.SGD(ms.parameters(), lr=lr)
    train_batch_f = partial(train_batch_distill,
                            on_before_optim_step=[
                                lambda: torch.nn.utils.clip_grad_norm_(ms.parameters(), gradnorm)
                            ])
    lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="min", factor=1 / 4,
                                                     patience=0, verbose=True)
    lrp_f = lambda: lrp.step(validloss.get_epoch_error())

    if distill == "rnnlm":
        mbase = m
        goldgetter = None
    elif distill == "glove":
        mbase = None
        tt.tick("creating gold getter based on glove")
        goldgetter = GloveGoldGetter(glovepath, worddic=D)
        goldgetter.to(device)
        tt.tock("created gold getter")
    else:
        raise q.SumTingWongException("unknown distill mode {}".format(distill))

    train_epoch_f = partial(train_epoch_distill, model=ms, dataloader=train_batches,
                            optim=optim, losses=[loss], device=device,
                            _train_batch=train_batch_f, mbase=mbase, goldgetter=goldgetter)
    valid_epoch_f = partial(q.test_epoch, model=ms, dataloader=valid_batches,
                            losses=validlosses, device=device, on_end=[lrp_f])
    tt.tock("prepared training student")

    tt.tick("training student model")
    q.run_training(train_epoch_f, valid_epoch_f, max_epochs=epochs, validinter=1)
    tt.tock("trained student model")

    tt.tick("testing student model")
    testresults = q.test_epoch(model=ms, dataloader=test_batches,
                               losses=testlosses, device=device)
    print(testresults)
    tt.tock("tested student model")
    # endregion