def __init__(self, version=11, feat_layer=9, **kw):
    """ Pretrained VGG-*version*, taking only the output of the first *feat_layer* layers.
        :param version:     11/13/16/19
        :param feat_layer:  number of feature layers to keep
    """
    super(SubVGG, self).__init__(**kw)
    v2f = {11: torchvision.models.vgg11,
           13: torchvision.models.vgg13,
           16: torchvision.models.vgg16,
           19: torchvision.models.vgg19}
    if version not in v2f:
        raise q.SumTingWongException(
            "vgg{} does not exist, please specify a valid version number (11, 13, 16 or 19)"
            .format(version))
    self.vgg = v2f[version](pretrained=True)    # docstring promises a pretrained VGG
    if feat_layer > len(self.vgg.features):
        raise q.SumTingWongException(
            "vgg{} does not have layer nr. {}. Please use a valid layer number."
            .format(version, feat_layer))
    self.layers = self.vgg.features[:feat_layer]

    def get_numth(num):     # ordinal suffix: 1st, 2nd, 3rd, 4th, ..., 11th, ..., 21st
        if num % 100 in (11, 12, 13):
            return "th"
        return {1: "st", 2: "nd", 3: "rd"}.get(num % 10, "th")

    print("using VGG{}'s {}{} layer's outputs ({})"
          .format(version, feat_layer, get_numth(feat_layer),
                  str(self.layers[feat_layer - 1])))
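# A minimal usage sketch, assuming SubVGG subclasses torch.nn.Module; since
# self.layers is a slice of an nn.Sequential, it can be applied directly.
# The input shape below is illustrative.
import torch

sub = SubVGG(version=11, feat_layer=9)    # first 9 layers of VGG11's feature stack
imgs = torch.randn(2, 3, 224, 224)        # dummy batch of 2 RGB images
feats = sub.layers(imgs)                  # truncated VGG feature maps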
def __init__(self, outdic, gen_out, inpdic=None, gen_zero=None, gen_outD=None, **kw):
    """
    :param outdic:      output dictionary, must contain all tokens in inpdic and gen_out.D
    :param gen_out:     module to compute generation scores.
                        Must have a dictionary accessible as ".D" (unless gen_outD is given).
                        Must produce unnormalized scores (no softmax).
    :param inpdic:      input dictionary (for the pointer)
    :param gen_zero:    None, or a set of tokens for which gen_out's probability is set to zero.
                        All such tokens should occur in inpdic (otherwise their score is always zero).
    :param gen_outD:    if set, used instead of gen_out.D (gen_out then need not have a ".D")
    :param kw:
    """
    super(PointerGeneratorOut, self).__init__(**kw)
    self.gen_out = gen_out
    self.D = outdic
    self.gen_outD = self.gen_out.D if gen_outD is None else gen_outD
    self.outsize = max(outdic.values()) + 1
    self.gen_to_out = q.val(torch.zeros(1, max(self.gen_outD.values()) + 1,
                                        dtype=torch.int64)).v
    # --> where in out to scatter every element of the gen
    self.gen_zero_mask = None if gen_zero is None else \
        q.val(torch.ones_like(self.gen_to_out, dtype=torch.float32)).v
    # (1, genvocsize), integer ids in outvoc, one-to-one mapping
    # if a symbol in gen_outD is not in outdic, an error is thrown
    for k, v in self.gen_outD.items():
        if k in outdic:
            self.gen_to_out[0, v] = outdic[k]
            if gen_zero is not None and k in gen_zero:
                self.gen_zero_mask[0, v] = 0
        else:
            raise q.SumTingWongException(
                "symbols in gen_outD must be in outdic, but \"{}\" isn't".format(k))
    self.inp_to_out = q.val(torch.zeros(max(inpdic.values()) + 1,
                                        dtype=torch.int64)).v
    # --> where in out to scatter every element of the inp
    # (inpvocsize,), integer ids in outvoc, one-to-one mapping
    # if a symbol in inpdic is not in outdic, an error is thrown
    for k, v in inpdic.items():
        if k in outdic:
            self.inp_to_out[v] = outdic[k]
        else:
            raise q.SumTingWongException(
                "symbols in inpdic must be in outdic, but \"{}\" isn't".format(k))
    self.sm = torch.nn.Softmax(-1)
    self._reset()
    self.check()
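# A minimal construction sketch with toy dictionaries; the rest of the class
# (_reset, check, forward) is assumed from the surrounding code, and ToyGen is
# a hypothetical stand-in generator exposing a ".D".
import torch

outdic = {"<mask>": 0, "a": 1, "b": 2, "c": 3}    # output vocab: union of both
gendic = {"<mask>": 0, "a": 1, "b": 2}            # what the generator can emit
inpdic = {"<mask>": 0, "b": 1, "c": 2}            # what the pointer can copy

class ToyGen(torch.nn.Module):
    D = gendic
    def forward(self, x):       # unnormalized generation scores, no softmax
        return torch.randn(x.size(0), max(self.D.values()) + 1)

pgo = PointerGeneratorOut(outdic, ToyGen(), inpdic=inpdic, gen_zero={"b"})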
def inception_logits(self, inpvar, num_splits=1):
    images = self.prepare_images(inpvar)
    generated_images_list = array_ops.split(images, num_or_size_splits=num_splits)
    if self.inception_version == "default":
        _fn = functools.partial(tfgan.eval.run_inception, output_tensor='logits:0')
    else:
        warnings.warn("use the default inception version if possible; "
                      "v1/v2/v3 might not be working (correctly)")
        if self.inception_version == "v1":
            inception_url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz"
            inception_file = "inception_v1_2016_08_28_frozen.pb"
            inception_path = os.path.join(self.inception_path, "inception_v1.pb")
            inception_outvar = "InceptionV1/Logits/SpatialSqueeze:0"
            inception_invar = "input:0"
        elif self.inception_version == "v2":
            inception_url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v2_2016_08_28_frozen.pb.tar.gz"
            inception_file = "inception_v2_2016_08_28_frozen.pb"
            inception_path = os.path.join(self.inception_path, "inception_v2.pb")
            inception_outvar = "InceptionV2/Logits/SpatialSqueeze:0"
            inception_invar = "input:0"
        elif self.inception_version == "v3":
            inception_url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz"
            inception_file = "inception_v3_2016_08_28_frozen.pb"
            inception_path = os.path.join(self.inception_path, "inception_v3.pb")
            inception_outvar = "InceptionV3/Logits/SpatialSqueeze:0"
            inception_invar = "input:0"
        else:
            raise q.SumTingWongException(
                "unknown inception version {}".format(self.inception_version))
        graphfn = tfgan.eval.get_graph_def_from_url_tarball(
            inception_url, inception_file, inception_path)
        _fn = functools.partial(tfgan.eval.run_inception,
                                graph_def=graphfn,
                                input_tensor=inception_invar,
                                output_tensor=inception_outvar)
    logits = functional_ops.map_fn(
        fn=_fn,
        elems=array_ops.stack(generated_images_list),
        parallel_iterations=1,
        back_prop=False,
        swap_memory=True,
        name='RunClassifier')
    logits = array_ops.concat(array_ops.unstack(logits), 0)
    return logits
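# A minimal call sketch (TF1-era tfgan/array_ops APIs, as used above); `scorer`
# and `images` are hypothetical names for an instance of this class and an
# (N, H, W, C) image tensor. More splits lower peak memory at the cost of more
# map_fn iterations.
logits = scorer.inception_logits(images, num_splits=2)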
def hook(self, f, *es, **kw):
    """ Binds *f* to be called when any of the events *es* happens.
        Returns a deleter that unbinds *f* again.
        PyTorch's lr schedulers can also be passed; if passing a
        ReduceLROnPlateau, also pass a function that can be called without
        arguments and that returns the metric for the Reducer.
    """
    if isinstance(f, AutoHooker):
        if len(es) > 0:
            raise q.SumTingWongException(
                "can't hook an AutoHooker explicitly on events")
        hookdic = f.get_hooks(self)
    else:
        hookdic = dict(zip(es, [f] * len(es)))
    for e, fe in hookdic.items():
        if e not in self._event_callbacks:
            self._event_callbacks[e] = []
        self._event_callbacks[e].append(fe)

    def deleter():      # unhooking mechanism promised by the docstring
        for e, fe in hookdic.items():
            self._event_callbacks[e].remove(fe)

    return deleter
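# A minimal usage sketch, assuming an emitter object exposing this hook() and
# firing string-named events; `emitter` and "end_epoch" are illustrative names.
def on_epoch_end(*args, **kw):
    print("epoch done")

undo = emitter.hook(on_epoch_end, "end_epoch")   # bind to the "end_epoch" event
# ... later:
undo()                                           # unbind the callback again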
def datacat(datasets, mode=1):
    if mode == 0:
        return torch.utils.data.dataset.ConcatDataset(datasets)
    elif mode == 1:
        return MultiDatasets(datasets)
    else:
        raise q.SumTingWongException("mode {} not recognized".format(mode))
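# A minimal usage sketch: mode 0 concatenates datasets end-to-end via PyTorch's
# ConcatDataset; mode 1 wraps them in MultiDatasets (assumed from the
# surrounding library).
import torch
from torch.utils.data import TensorDataset

ds1 = TensorDataset(torch.arange(10).float())
ds2 = TensorDataset(torch.arange(5).float())
cat = datacat([ds1, ds2], mode=0)    # len(cat) == 15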
def forward(self, x, x_mask=None):
    z, log_posterior = self.encoder(x, x_mask=x_mask)
    _x_mask = x_mask[:, 1:] if x_mask is not None else None
    if z.dim() == 2:        # a single latent vector per example
        log_prior = log_prob_standard_gauss(z)
        x_hat = self.decoder(x[:, :-1], z=z)
    elif z.dim() == 3:      # a latent vector per time step
        log_prior = log_prob_seq_standard_gauss(z, mask=_x_mask)
        x_hat = self.decoder([x[:, :-1], z])
    else:
        raise q.SumTingWongException("z must be 2D or 3D, got {}D".format(z.dim()))
    log_likelihood = self.likelihood(x_hat, x[:, 1:], x_mask=_x_mask)
    kl_div = log_posterior - log_prior
    elbo = log_likelihood - kl_div
    rets = -elbo, kl_div, -log_likelihood
    if self._debug:
        z_grad = torch.autograd.grad(log_likelihood.sum(), z, retain_graph=True)
        z_grad = z_grad[0] ** 2
        if z.dim() == 3:
            z_grad = z_grad.sum(2)
        z_grad = z_grad.sum(1) ** 0.5   # per-example gradient norm w.r.t. z
        rets = rets + (z_grad,)
    return rets
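# The returned triple follows the usual ELBO decomposition:
#   ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))
# so minimizing the first return value (-ELBO) maximizes likelihood while
# regularizing the posterior toward the standard-Gaussian prior.
# A minimal training-step sketch (model/optimizer/batch names are illustrative):
loss, kl, nll = model(x_batch, x_mask=mask_batch)[:3]
loss.mean().backward()
optimizer.step()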
def merge(self, wordemb, mode="sum"): """ Merges this embedding with provided embedding using the provided mode. The dictionary of provided embedding must be identical to this embedding. """ if not wordemb.D == self.D: raise q.SumTingWongException("must have identical dictionary") return MergedWordEmb(self, wordemb, mode=mode)
def __init__(self, base, merge, mode="sum"): super(MergedWordVecBase, self).__init__(base.D) self.base = base self.merg = merge self.mode = mode if not mode in ("sum", "cat"): raise q.SumTingWongException( "{} merge mode not suported".format(mode))
def __init__(self, p=None, prefix=None, **kw):
    super(Logger, self).__init__()
    assert p is None or prefix is None     # give a full path or a prefix, not both
    self.p = p if p is not None else get_default_log_path(prefix)
    if os.path.exists(self.p):
        raise q.SumTingWongException("path '{}' already exists".format(self.p))
    else:
        os.makedirs(self.p)
    self._current_train_file = None
    self._current_numbers = []
    self.open_liners = {}
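# A minimal usage sketch; get_default_log_path(prefix) is assumed to produce a
# fresh directory name derived from the prefix:
logger = Logger(prefix="experiments/vae")    # creates a new run directory
# or pin an exact path (must not exist yet):
logger = Logger(p="experiments/vae/run42")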
def forward(self, x):
    base_emb, base_msk = self.base(x)
    merg_emb, merg_msk = self.merg(x)
    if self.mode == "sum":
        emb = base_emb + merg_emb
        msk = base_msk      # since dictionaries are identical, the masks match
    elif self.mode == "cat":
        emb = torch.cat([base_emb, merg_emb], 1)
        msk = base_msk
    else:
        raise q.SumTingWongException("unknown merge mode: {}".format(self.mode))
    return emb, msk
def flatten_chain(chainspec):
    flatchainspec = []
    for x in chainspec:
        if x in ("+", "-"):         # direction markers are kept as-is
            flatchainspec.append(x)
        elif x > -1:                # relation id: replace by its words
            relwords = rels[str(x)]
            flatchainspec += relwords
        elif x == -1:               # padding: skip
            pass
        else:
            raise q.SumTingWongException("unexpected symbol in chain")
    return " ".join(flatchainspec)
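# A small worked example, assuming the global `rels` maps relation ids (as
# strings) to lists of words:
rels = {"12": ["located", "in"], "7": ["capital", "of"]}
print(flatten_chain(["+", 12, "-", 7, -1]))   # -> "+ located in - capital of"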
def start_logging(self, names=None, logname=None, overwrite=False, **kw):
    # make writer
    if os.path.exists(self.path):
        if not overwrite:
            raise q.SumTingWongException("file already exists")
        else:
            warnings.warn("training log file already exists. overwriting {}".format(self.path))
    self._current_file = open(self.path, "w+")
    if names is None:       # default header: batch counter + loss names
        names = ["N."]
        names += [x.get_name() for x in self.looper.losses.losses]
    line = "\t".join(names) + "\n"
    self._current_file.write(line)
    self._current_file.flush()
def forward(self, x, mask=None, _do_cosnorm=False, _retcosnorm=False, _no_mask_log=False):
    if self.mode == "cat":      # need to split up the input
        basex = x[:, :self.base.vecdim]
        mergx = x[:, self.base.vecdim:]
        # TODO: not all wordlinouts have .vecdim
    elif self.mode == "sum":
        basex, mergx = x, x
    else:
        raise q.SumTingWongException("unknown merge mode: {}".format(self.mode))
    baseres = self.base(basex, mask=mask, _do_cosnorm=False,
                        _retcosnorm=_retcosnorm or self.cosnorm or _do_cosnorm,
                        _no_mask_log=_no_mask_log)
    mergres = self.merg(mergx, mask=mask, _do_cosnorm=False,
                        _retcosnorm=_retcosnorm or self.cosnorm or _do_cosnorm,
                        _no_mask_log=_no_mask_log)
    if _retcosnorm or self.cosnorm or _do_cosnorm:
        baseres, basecosnorm = baseres
        mergres, mergcosnorm = mergres
        cosnorm = basecosnorm + mergcosnorm
    res = baseres + mergres
    if _retcosnorm:
        return res, cosnorm
    if self.cosnorm or _do_cosnorm:     # normalize by input norm and weight norms
        res = res / torch.clamp(torch.norm(x, 2, 1).unsqueeze(1), min=EPS)
        res = res / torch.clamp(cosnorm, min=EPS).pow(1. / 2)
    return res
def iscuda(x):
    if isinstance(x, torch.nn.Module):
        params = list(x.parameters())
        return params[0].is_cuda
    else:
        raise q.SumTingWongException("unsupported type: {}".format(type(x)))
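# A minimal usage sketch: checks where a module's (first) parameter lives.
import torch

lin = torch.nn.Linear(4, 4)
print(iscuda(lin))                   # False on a CPU-only machine
if torch.cuda.is_available():
    print(iscuda(lin.cuda()))        # True after moving to GPU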
def _forward(self, scores, gold, mask=None):
    # scores: (batsize, numvoc), gold: idx^(batsize,)
    # scores = scores - scores.min()
    goldscores = torch.gather(scores, 1, gold.unsqueeze(1)).squeeze()

    if mask is not None and mask[0, 1] > 1:     # batchable-sparse mask: densify
        mask = q.batchablesparse2densemask(mask)

    goldexamplemask = None

    if self.negmode == "random" or self.negmode == "negall":
        # build a distribution over negative candidates
        sampledist = scores.new(scores.size()).to(scores.device)
        sampledist.fill_(1.)
        sampledist.scatter_(1, gold.unsqueeze(1), 0)     # never sample gold
        filtermask = scores > -np.inf
        if mask is not None:
            filtermask = filtermask & mask.byte()
        sampledist = sampledist * filtermask.float()
        sampledist_orig = sampledist
        if self.margin is not None and self.ignore_below_margin:
            cutoffs = goldscores - self.margin
            cutoffmask = scores > cutoffs.unsqueeze(1)   # drop negatives below margin
            sampledist = sampledist * cutoffmask.float()
        if (sampledist.sum(1) > 0).long().sum() < gold.size(0):
            # some examples have no valid negatives left: force to sample gold
            gold_onehot = torch.ByteTensor(sampledist.size()).to(sampledist.device)
            gold_onehot.fill_(0)
            gold_onehot.scatter_(1, gold.unsqueeze(1), 1)
            goldexamplemask = (sampledist.sum(1) != 0)
            # addtosampledist = sampledist_orig * examplemask.float().unsqueeze(1)
            addtosampledist = gold_onehot * (~goldexamplemask).unsqueeze(1)
            sampledist.masked_fill_(addtosampledist, 1)
        if self.negmode == "random":
            sample = torch.multinomial(sampledist, 1)
            negscores = torch.gather(scores, 1, sample).squeeze()
        elif self.negmode == "negall":
            negscores = scores * sampledist
            numnegs = sampledist.sum(1)
    elif self.negmode == "best":
        # take the best-scoring wrong answer as the negative
        # scores = scores * mask.float() if mask else scores
        scores = scores + torch.log(mask.float()) if mask is not None else scores
        bestscores, best = torch.max(scores, 1)
        secondscores = scores + 0
        secondscores.scatter_(1, best.unsqueeze(1), 0)
        secondbestscores, secondbest = torch.max(secondscores, 1)
        switchmask = best == gold      # if the best is gold, use the second best
        sample = secondbest * switchmask.long() + best * (1 - switchmask.long())
        negscores = secondbestscores * switchmask.float() + bestscores * (1 - switchmask.float())
        goldexamplemask = sample.squeeze() != gold
        # raise NotImplemented("some issues regarding implementation not resolved")
    else:
        raise q.SumTingWongException("unknown mode: {}".format(self.negmode))

    if self.negmode == "best" or self.negmode == "random":
        loss = negscores - goldscores
        if self.margin is not None:
            loss = torch.clamp(self.margin + loss, min=0)
        if goldexamplemask is not None:
            loss = goldexamplemask.float() * loss
    elif self.negmode == "negall":      # negscores are 2D
        loss = negscores - goldscores.unsqueeze(1)
        if self.margin is not None:
            loss = torch.clamp(self.margin + loss, min=0)
        loss = loss * sampledist
        loss = loss.sum(1)
        if self._average_negall:
            loss = loss / numnegs
        if goldexamplemask is not None:
            loss = loss * goldexamplemask.float()

    ignoremask = self._get_ignore_mask(gold)
    if ignoremask is not None:
        loss = loss * ignoremask.float()
    return loss, ignoremask
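# The core of the "random"/"best" modes is a pairwise hinge:
#   loss_i = max(0, margin + s_neg_i - s_gold_i)
# A tiny numeric check (margin = 1):
import torch

s_gold = torch.tensor([3.0, 0.5])
s_neg = torch.tensor([2.5, 1.0])
print(torch.clamp(1.0 + s_neg - s_gold, min=0))   # tensor([0.5000, 1.5000])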
def on_end_epoch(self, owner, **kw):
    if not isinstance(owner, trainer):
        raise q.SumTingWongException("can only be hooked to a trainer")
    epoch = owner.current_epoch
    maxepochs = owner.epochs
    self.do_epoch(epoch, maxepochs)
def merge(self, x, mode="sum"): x.cosnorm = self.cosnorm if not self.D == x.D: raise q.SumTingWongException() return MergedWordLinout(self, x, mode=mode)
def load_jsons(datap="../../../datasets/lcquad/newdata.json",
               relp="../../../datasets/lcquad/nrelations.json",
               mode="flat"):
    tt = q.ticktock("data loader")
    tt.tick("loading jsons")
    data = json.load(open(datap))
    rels = json.load(open(relp))
    tt.tock("jsons loaded")

    tt.tick("extracting data")
    questions = []
    goldchains = []
    badchains = []
    for dataitem in data:
        questions.append(dataitem["parsed-data"]["corrected_question"])
        goldchain = []
        for x in dataitem["parsed-data"]["path_id"]:
            goldchain += [x[0], int(x[1:])]     # direction marker + relation id
        goldchains.append(goldchain)
        badchainses = []
        goldfound = False
        for badchain in dataitem["uri"]["hop-1-properties"] + dataitem["uri"]["hop-2-properties"]:
            if goldchain == badchain:
                goldfound = True
            else:
                if len(badchain) == 2:          # pad 1-hop chains to 2-hop length
                    badchain += [-1, -1]
                badchainses.append(badchain)
        badchains.append(badchainses)
    tt.tock("extracted data")
    tt.msg("mode: {}".format(mode))

    if mode == "flat":
        tt.tick("flattening")

        def flatten_chain(chainspec):
            flatchainspec = []
            for x in chainspec:
                if x in ("+", "-"):
                    flatchainspec.append(x)
                elif x > -1:
                    relwords = rels[str(x)]
                    flatchainspec += relwords
                elif x == -1:       # padding: skip
                    pass
                else:
                    raise q.SumTingWongException("unexpected symbol in chain")
            return " ".join(flatchainspec)

        goldchainids = []
        badchainsids = []
        uniquechainids = {}

        qsm = q.StringMatrix()
        csm = q.StringMatrix()
        csm.tokenize = lambda x: x.lower().strip().split()

        def get_ensure_chainid(flatchain):
            if flatchain not in uniquechainids:
                uniquechainids[flatchain] = len(uniquechainids)
                csm.add(flatchain)
                assert len(csm) == len(uniquechainids)
            return uniquechainids[flatchain]

        eid = 0
        numchains = 0
        for question, goldchain, badchainses in zip(questions, goldchains, badchains):
            qsm.add(question)
            # flatten gold chain
            flatgoldchain = flatten_chain(goldchain)
            chainid = get_ensure_chainid(flatgoldchain)
            goldchainids.append(chainid)
            badchainsids.append([])
            numchains += 1
            for badchain in badchainses:
                flatbadchain = flatten_chain(badchain)
                chainid = get_ensure_chainid(flatbadchain)
                badchainsids[eid].append(chainid)
                numchains += 1
            eid += 1
            tt.live("{}".format(eid))
        assert len(badchainsids) == len(questions)
        tt.stoplive()
        tt.msg("{} unique chains from {} total".format(len(csm), numchains))
        qsm.finalize()
        csm.finalize()
        tt.tock("flattened")
        csm.tokenize = None
        return qsm, csm, goldchainids, badchainsids
    else:
        raise q.SumTingWongException("unsupported mode: {}".format(mode))
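# A minimal call sketch; the default paths point at the LC-QuAD files this
# repository expects, so override them for your own layout:
qsm, csm, goldids, badids = load_jsons(mode="flat")
# qsm: question matrix, csm: unique-chain matrix,
# goldids[i]: gold chain id for question i, badids[i]: its negative chain ids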
def logc(self, x, looper, logfilename):
    """ Keeps logging the state of *x*, based on the event emitter's events,
        and stores it in *logfilename*.
        Smart method --> dispatches according to the type of x.
    """
    raise q.SumTingWongException("don't know how to log object of type {}".format(type(x)))