def __init__(self, version=11, feat_layer=9, **kw):
        """
        Pretrained VGG-*version*, keeping only the first *feat_layer* feature layers (i.e. the *feat_layer*-th layer's output).
        :param version:     11/13/16/19
        :param feat_layer:  number of layers of self.vgg.features to keep
        """
        super(SubVGG, self).__init__(**kw)
        v2f = {
            11: torchvision.models.vgg11,
            13: torchvision.models.vgg13,
            16: torchvision.models.vgg16,
            19: torchvision.models.vgg19,
        }
        if version not in v2f:
            raise q.SumTingWongException(
                "vgg{} does not exist, please specify valid version number (11, 13, 16 or 19)"
                .format(version))
        self.vgg = v2f[version](pretrained=False)
        if feat_layer > len(self.vgg.features):
            raise q.SumTingWongException(
                "vgg{} does not have layer nr. {}. Please use a valid layer number."
                .format(version, feat_layer))
        self.layers = self.vgg.features[:feat_layer]

        def get_numth(num):
            # ordinal suffix: 11th-13th are special, otherwise the last digit decides
            if num % 100 in (11, 12, 13):
                return "th"
            return {1: "st", 2: "nd", 3: "rd"}.get(num % 10, "th")

        print("using VGG{}'s {}{} layer's outputs ({})".format(
            version, feat_layer, get_numth(feat_layer),
            str(self.layers[feat_layer - 1])))
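
A minimal usage sketch of the truncated network above, under the assumption that the surrounding SubVGG class is a torch.nn.Module and that self.layers (a slice of vgg.features, itself an nn.Sequential) is what gets applied:

import torch

model = SubVGG(version=16, feat_layer=9)       # keep the first 9 layers of VGG16's feature stack
model.eval()
with torch.no_grad():
    images = torch.randn(4, 3, 224, 224)       # dummy batch of RGB images
    feats = model.layers(images)               # (4, C, H', W') feature maps
print(feats.shape)
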
    def __init__(self,
                 outdic,
                 gen_out,
                 inpdic=None,
                 gen_zero=None,
                 gen_outD=None,
                 **kw):
        """
                :param outdic:          output dictionary, must contain all tokens in inpdic and gen_out.D
                :param gen_prob_comp:   module to compute probability of generating vs pointing
                                        must produce (batsize, 1) shapes
                :param gen_out:         module to compute generation scores.
                                            must have a dictionary accessible as ".D".
                                            must produce unnormalized scores (no softmax)
                :param inpdic:          input dictionary (for pointer)
                :param gen_zero:        None or set of tokens for which the gen_out's prob will be set to zero.
                                        All tokens should occur in inpdic (or their score will always be zero)
                :param gen_outD:        if set, gen_out must not have a ".D"
                :param kw:
                """
        super(PointerGeneratorOut, self).__init__(**kw)
        self.gen_out = gen_out
        self.D = outdic
        self.gen_outD = self.gen_out.D if gen_outD is None else gen_outD
        self.outsize = max(outdic.values()) + 1
        self.gen_to_out = q.val(
            torch.zeros(1, max(self.gen_outD.values()) + 1,
                        dtype=torch.int64)).v
        # --> where in out to scatter every element of the gen
        self.gen_zero_mask = None if gen_zero is None else \
            q.val(torch.ones_like(self.gen_to_out, dtype=torch.float32)).v
        # (1, genvocsize), integer ids in outvoc, one-to-one mapping
        # if symbol in gendic is not in outdic, throws error
        for k, v in self.gen_outD.items():
            if k in outdic:
                self.gen_to_out[0, v] = outdic[k]
                if gen_zero is not None:
                    if k in gen_zero:
                        self.gen_zero_mask[0, v] = 0
            else:
                raise q.SumTingWongException(
                    "symbols in gen_outD must be in outdic, but \"{}\" isn't".
                    format(k))

        self.inp_to_out = q.val(
            torch.zeros(max(inpdic.values()) + 1, dtype=torch.int64)).v
        # --> where in out to scatter every element of the inp
        # (inpvocsize,), integer ids in outvoc, one-to-one mapping
        # if symbol in inpdic is not in outdic, throws error
        for k, v in inpdic.items():
            if k in outdic:
                self.inp_to_out[v] = outdic[k]
            else:
                raise q.SumTingWongException(
                    "symbols in inpdic must be in outdic, but \"{}\" isn't".
                    format(k))
        self.sm = torch.nn.Softmax(-1)
        self._reset()
        self.check()
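
A small toy sketch of the index map the loop above builds, with made-up dictionaries (the q.val wrapper is assumed to simply hold the tensor and expose it as .v):

import torch

# toy dictionaries: every generator symbol must also appear in outdic
outdic = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3}
gen_outD = {"<pad>": 0, "cat": 1, "sat": 2}

gen_to_out = torch.zeros(1, max(gen_outD.values()) + 1, dtype=torch.int64)
for k, v in gen_outD.items():
    gen_to_out[0, v] = outdic[k]   # where in outdic to scatter the score of generator symbol v
print(gen_to_out)                  # tensor([[0, 2, 3]])
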
Example #3
    def inception_logits(self, inpvar, num_splits=1):
        images = self.prepare_images(inpvar)
        generated_images_list = array_ops.split(images,
                                                num_or_size_splits=num_splits)
        if self.inception_version == "default":
            _fn = functools.partial(tfgan.eval.run_inception,
                                    output_tensor='logits:0')
        else:
            raise q.SumTingWongException(
                "use the default inception version; the other versions may not work correctly"
            )
            if self.inception_version == "v1":
                inception_url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz"
                inception_file = "inception_v1_2016_08_28_frozen.pb"
                inception_path = os.path.join(self.inception_path,
                                              "inception_v1.pb")
                inception_outvar = "InceptionV1/Logits/SpatialSqueeze:0"
                inception_invar = "input:0"
            elif self.inception_version == "v2":
                inception_url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v2_2016_08_28_frozen.pb.tar.gz"
                inception_file = "inception_v2_2016_08_28_frozen.pb"
                inception_path = os.path.join(self.inception_path,
                                              "inception_v2.pb")
                inception_outvar = "InceptionV2/Logits/SpatialSqueeze:0"
                inception_invar = "input:0"
            elif self.inception_version == "v3":
                inception_url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz"
                inception_file = "inception_v3_2016_08_28_frozen.pb"
                inception_path = os.path.join(self.inception_path,
                                              "inception_v3.pb")
                inception_outvar = "InceptionV3/Logits/SpatialSqueeze:0"
                inception_invar = "input:0"
            else:
                raise q.SumTingWongException(
                    "unknown inception version {}".format(
                        self.inception_version))

            graphfn = tfgan.eval.get_graph_def_from_url_tarball(
                inception_url, inception_file, inception_path)
            _fn = functools.partial(
                tfgan.eval.run_inception,
                graph_def=graphfn,
                input_tensor=inception_invar,
                output_tensor=inception_outvar)  #'logits:0')

        logits = functional_ops.map_fn(
            fn=_fn,
            elems=array_ops.stack(generated_images_list),
            parallel_iterations=1,
            back_prop=False,
            swap_memory=True,
            name='RunClassifier')
        logits = array_ops.concat(array_ops.unstack(logits), 0)
        return logits
Example #4
    def hook(self, f, *es, **kw):
        """ f to be called when e happens. Returns deleter for bound f
            can also pass pytorch's lr schedulers
            if passing a ReduceLROnPlateau, must also pass a function that can be called without arguments
                and that returns the metric for Reducer
        """
        if isinstance(f, AutoHooker):
            if len(es) > 0:
                raise q.SumTingWongException(
                    "can't hook an AutoHooker on explicitly given events")
            hookdic = f.get_hooks(self)
        else:
            hookdic = dict(zip(es, [f] * len(es)))

        for e, fe in hookdic.items():
            if e not in self._event_callbacks:
                self._event_callbacks[e] = []
            self._event_callbacks[e].append(fe)

        def deleter():
            for e, fe in hookdic.items():
                self._event_callbacks[e].remove(fe)

        # TODO: implement unhooking mechanism
        return self
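
A toy sketch of the bookkeeping the loop above performs: one callback is appended to the list of every event it is bound to (the event names here are made up; the real identifiers depend on the owning trainer):

def log_epoch(**kw):
    print("epoch done")

event_callbacks = {}
for e in ("start_epoch", "end_epoch"):               # hypothetical event names
    event_callbacks.setdefault(e, []).append(log_epoch)

for cb in event_callbacks["end_epoch"]:              # what the owner does when the event fires
    cb()
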
Example #5
def datacat(datasets, mode=1):
    if mode == 0:
        return torch.utils.data.dataset.ConcatDataset(datasets)
    elif mode == 1:
        return MultiDatasets(datasets)
    else:
        raise q.SumTingWongException("mode {} not recognized".format(mode))
Example #6
 def forward(self, x, x_mask=None):
     z, log_posterior = self.encoder(x, x_mask=x_mask)
     _x_mask = x_mask[:, 1:] if x_mask is not None else None
     if z.dim() == 2:
         log_prior = log_prob_standard_gauss(z)
         x_hat = self.decoder(x[:, :-1], z=z)
     elif z.dim() == 3:
         log_prior = log_prob_seq_standard_gauss(z, mask=_x_mask)
         x_hat = self.decoder([x[:, :-1], z])
     else:
         raise q.SumTingWongException("z must be 2D or 3D, got {}D".format(
             z.dim()))
     log_likelihood = self.likelihood(x_hat, x[:, 1:], x_mask=_x_mask)
     kl_div = log_posterior - log_prior
     elbo = log_likelihood - kl_div
     rets = -elbo, kl_div, -log_likelihood
     if self._debug:
         z_grad = torch.autograd.grad(log_likelihood.sum(),
                                      z,
                                      retain_graph=True)
         z_grad = z_grad[0]**2
         if z.dim() == 3:
             z_grad = z_grad.sum(2)
         z_grad = z_grad.sum(1)**0.5
         rets = rets + (z_grad, )
     return rets
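
A small sanity sketch of the quantities returned above: the negative ELBO equals the negative log-likelihood plus the KL divergence (toy per-example numbers):

import torch

log_likelihood = torch.tensor([-3.2, -1.7])
kl_div = torch.tensor([0.4, 0.9])
elbo = log_likelihood - kl_div
print(torch.allclose(-elbo, -log_likelihood + kl_div))   # True
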
Example #7
 def merge(self, wordemb, mode="sum"):
     """
     Merges this embedding with the provided embedding using the provided mode.
     The provided embedding's dictionary must be identical to this embedding's.
     """
     if not wordemb.D == self.D:
         raise q.SumTingWongException("must have identical dictionary")
     return MergedWordEmb(self, wordemb, mode=mode)
Example #8
 def __init__(self, base, merge, mode="sum"):
     super(MergedWordVecBase, self).__init__(base.D)
     self.base = base
     self.merg = merge
     self.mode = mode
     if mode not in ("sum", "cat"):
         raise q.SumTingWongException(
             "{} merge mode not supported".format(mode))
Example #9
 def __init__(self, p=None, prefix=None, **kw):
     super(Logger, self).__init__()
     assert(p is None or prefix is None)
     self.p = p if p is not None else get_default_log_path(prefix)
     if os.path.exists(self.p):
         raise q.SumTingWongException("path '{}' already exists".format(p))
     else:
         os.makedirs(self.p)
     self._current_train_file = None
     self._current_numbers = []
     self.open_liners = {}
Example #10
 def forward(self, x):
     base_emb, base_msk = self.base(x)
     merg_emb, merg_msk = self.merg(x)
     if self.mode == "sum":
         emb = base_emb + merg_emb
         msk = base_msk  # since dictionaries are identical
     elif self.mode == "cat":
         emb = torch.cat([base_emb, merg_emb], 1)
         msk = base_msk
     else:
         raise q.SumTingWongException("unknown merge mode: {}".format(self.mode))
     return emb, msk
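
A toy sketch of the two merge modes above, showing shapes only (masks omitted, embedding dimensionality made up):

import torch

base_emb = torch.randn(2, 4)                   # (batsize, dim)
merg_emb = torch.randn(2, 4)

summed = base_emb + merg_emb                   # mode == "sum": dimensionality unchanged
catted = torch.cat([base_emb, merg_emb], 1)    # mode == "cat": dimensionalities add up
print(summed.shape, catted.shape)              # torch.Size([2, 4]) torch.Size([2, 8])
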
Example #11
 def flatten_chain(chainspec):
     flatchainspec = []
     for x in chainspec:
         if x in ("+", "-"):
             flatchainspec.append(x)
         elif x > -1:
             relwords = rels[str(x)]
             flatchainspec += relwords
         elif x == -1:
             pass
         else:
             raise q.SumTingWongException("unexpected symbol in chain")
     return " ".join(flatchainspec)
Example #12
 def start_logging(self, names=None, logname=None, overwrite=False, **kw):
     # make writer
     if os.path.exists(self.path):
         if not overwrite:
             raise q.SumTingWongException("file already exists")
         else:
             warnings.warn("training log file already exists. overwriting {}".format(self.path))
     self._current_file = open(self.path, "w+")
     names = ["N."]
     names += [x.get_name() for x in self.looper.losses.losses]
     line = "\t".join(names) + "\n"
     self._current_file.write(line)
     self._current_file.flush()
Example #13
 def forward(self,
             x,
             mask=None,
             _do_cosnorm=False,
             _retcosnorm=False,
             _no_mask_log=False):
     if self.mode == "cat":  # need to split up input
         basex = x[:, :self.base.vecdim]
         mergx = x[:, self.base.vecdim:]
         # TODO: not all wordlinouts have .vecdim
     elif self.mode == "sum":
         basex, mergx = x, x
     else:
         raise q.SumTingWongException("unknown merge mode: {}".format(self.mode))
     baseres = self.base(basex,
                         mask=mask,
                         _do_cosnorm=False,
                         _retcosnorm=_retcosnorm or self.cosnorm
                         or _do_cosnorm,
                         _no_mask_log=_no_mask_log)
     mergres = self.merg(mergx,
                         mask=mask,
                         _do_cosnorm=False,
                         _retcosnorm=_retcosnorm or self.cosnorm
                         or _do_cosnorm,
                         _no_mask_log=_no_mask_log)
     if _retcosnorm or self.cosnorm or _do_cosnorm:
         baseres, basecosnorm = baseres
         mergres, mergcosnorm = mergres
         cosnorm = basecosnorm + mergcosnorm
     res = baseres + mergres
     if _retcosnorm:
         return res, cosnorm
     if self.cosnorm or _do_cosnorm:
         res = res / torch.clamp(torch.norm(x, 2, 1).unsqueeze(1), min=EPS)
         res = res / torch.clamp(cosnorm, min=EPS).pow(1. / 2)
     return res
Example #14
def iscuda(x):
    if isinstance(x, torch.nn.Module):
        params = list(x.parameters())
        return params[0].is_cuda
    else:
        raise q.SumTingWongException("unsupported type")
Example #15
    def _forward(self,
                 scores,
                 gold,
                 mask=None):  # (batsize, numvoc), idx^(batsize,)
        # scores = scores - scores.min()
        goldscores = torch.gather(scores, 1, gold.unsqueeze(1)).squeeze()

        if mask is not None and mask[0, 1] > 1:
            mask = q.batchablesparse2densemask(mask)

        goldexamplemask = None

        if self.negmode == "random" or self.negmode == "negall":
            sampledist = scores.new(scores.size()).to(scores.device)
            sampledist.fill_(1.)
            sampledist.scatter_(1, gold.unsqueeze(1), 0)
            filtermask = scores > -np.inf
            if mask is not None:
                filtermask = filtermask & mask.byte()
            sampledist = sampledist * filtermask.float()
            sampledist_orig = sampledist
            if self.margin is not None and self.ignore_below_margin:
                cutoffs = goldscores - self.margin
                cutoffmask = scores > cutoffs.unsqueeze(1)
                sampledist = sampledist * cutoffmask.float()
            if (sampledist.sum(1) > 0).long().sum() < gold.size(0):
                # force to sample gold
                gold_onehot = torch.ByteTensor(sampledist.size()).to(
                    sampledist.device)
                gold_onehot.fill_(0)
                gold_onehot.scatter_(1, gold.unsqueeze(1), 1)
                goldexamplemask = (sampledist.sum(1) != 0)
                # addtosampledist = sampledist_orig * examplemask.float().unsqueeze(1)
                addtosampledist = gold_onehot * (~goldexamplemask).unsqueeze(1)
                sampledist.masked_fill_(addtosampledist, 1)
            if self.negmode == "random":
                sample = torch.multinomial(sampledist, 1)
                negscores = torch.gather(scores, 1, sample).squeeze()
            elif self.negmode == "negall":
                negscores = scores * sampledist
                numnegs = sampledist.sum(1)
        elif self.negmode == "best":
            # scores = scores * mask.float() if mask else scores
            scores = scores + torch.log(mask.float()) if mask is not None else scores
            bestscores, best = torch.max(scores, 1)
            secondscores = scores + 0
            secondscores.scatter_(1, best.unsqueeze(1), 0)
            secondbestscores, secondbest = torch.max(secondscores, 1)
            switchmask = best == gold
            sample = secondbest * switchmask.long() + best * (
                1 + (-1) * switchmask.long())
            negscores = secondbestscores * switchmask.float() + bestscores * (
                1 - switchmask.float())
            goldexamplemask = sample.squeeze() != gold
            # raise NotImplemented("some issues regarding implementation not resolved")
        else:
            raise q.SumTingWongException("unknown mode: {}".format(
                self.negmode))

        if self.negmode == "best" or self.negmode == "random":
            loss = negscores - goldscores
            if self.margin is not None:
                loss = torch.clamp(self.margin + loss, min=0)
            if goldexamplemask is not None:
                loss = goldexamplemask.float() * loss
        elif self.negmode == "negall":
            # negscores are 2D
            loss = negscores - goldscores.unsqueeze(1)
            if self.margin is not None:
                loss = torch.clamp(self.margin + loss, min=0)
            loss = loss * sampledist
            loss = loss.sum(1)
            if self._average_negall:
                loss = loss / numnegs
            if goldexamplemask is not None:
                loss = loss * goldexamplemask.float()

        ignoremask = self._get_ignore_mask(gold)
        if ignoremask is not None:
            loss = loss * ignoremask.float()

        return loss, ignoremask
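
A toy worked example of the hinge the "random"/"best" branches compute above, i.e. loss = clamp(margin + neg_score - gold_score, min=0) (numbers made up):

import torch

gold_scores = torch.tensor([2.0, 0.5])
neg_scores = torch.tensor([1.0, 1.5])
margin = 0.5
loss = torch.clamp(margin + neg_scores - gold_scores, min=0)
print(loss)   # tensor([0.0000, 1.5000])
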
Example #16
 def on_end_epoch(self, owner, **kw):
     if not isinstance(owner, trainer):
         raise q.SumTingWongException("can only be hooked to a trainer")
     epoch = owner.current_epoch
     maxepochs = owner.epochs
     self.do_epoch(epoch, maxepochs)
Example #17
 def merge(self, x, mode="sum"):
     x.cosnorm = self.cosnorm
     if not self.D == x.D:
         raise q.SumTingWongException("must have identical dictionary")
     return MergedWordLinout(self, x, mode=mode)
Example #18
def load_jsons(datap="../../../datasets/lcquad/newdata.json",
               relp="../../../datasets/lcquad/nrelations.json",
               mode="flat"):
    tt = q.ticktock("data loader")
    tt.tick("loading jsons")

    data = json.load(open(datap))
    rels = json.load(open(relp))

    tt.tock("jsons loaded")

    tt.tick("extracting data")
    questions = []
    goldchains = []
    badchains = []
    for dataitem in data:
        questions.append(dataitem["parsed-data"]["corrected_question"])
        goldchain = []
        for x in dataitem["parsed-data"]["path_id"]:
            goldchain += [x[0], int(x[1:])]
        goldchains.append(goldchain)
        badchainses = []
        goldfound = False
        for badchain in dataitem["uri"]["hop-1-properties"] + dataitem["uri"][
                "hop-2-properties"]:
            if goldchain == badchain:
                goldfound = True
            else:
                if len(badchain) == 2:
                    badchain += [-1, -1]
                badchainses.append(badchain)
        badchains.append(badchainses)

    tt.tock("extracted data")

    tt.msg("mode: {}".format(mode))

    if mode == "flat":
        tt.tick("flattening")

        def flatten_chain(chainspec):
            flatchainspec = []
            for x in chainspec:
                if x in ("+", "-"):
                    flatchainspec.append(x)
                elif x > -1:
                    relwords = rels[str(x)]
                    flatchainspec += relwords
                elif x == -1:
                    pass
                else:
                    raise q.SumTingWongException("unexpected symbol in chain")
            return " ".join(flatchainspec)

        goldchainids = []
        badchainsids = []

        uniquechainids = {}

        qsm = q.StringMatrix()
        csm = q.StringMatrix()
        csm.tokenize = lambda x: x.lower().strip().split()

        def get_ensure_chainid(flatchain):
            if flatchain not in uniquechainids:
                uniquechainids[flatchain] = len(uniquechainids)
                csm.add(flatchain)
                assert (len(csm) == len(uniquechainids))
            return uniquechainids[flatchain]

        eid = 0
        numchains = 0
        for question, goldchain, badchainses in zip(questions, goldchains,
                                                    badchains):
            qsm.add(question)
            # flatten gold chain
            flatgoldchain = flatten_chain(goldchain)
            chainid = get_ensure_chainid(flatgoldchain)
            goldchainids.append(chainid)
            badchainsids.append([])
            numchains += 1
            for badchain in badchainses:
                flatbadchain = flatten_chain(badchain)
                chainid = get_ensure_chainid(flatbadchain)
                badchainsids[eid].append(chainid)
                numchains += 1
            eid += 1
            tt.live("{}".format(eid))

        assert (len(badchainsids) == len(questions))
        tt.stoplive()
        tt.msg("{} unique chains from {} total".format(len(csm), numchains))
        qsm.finalize()
        csm.finalize()
        tt.tock("flattened")
        csm.tokenize = None
        return qsm, csm, goldchainids, badchainsids
    else:
        raise q.SumTingWongException("unsupported mode: {}".format(mode))
Example #19
 def logc(self, x, looper, logfilename):
     """ keep logging state of "x" based on "eventemitter"'s events and store in "logfilename"
         smart method --> dispatches according to type of x
     """
     raise q.SumTingWongException()