Пример #1
0
    def __call__(self, iter=None):
        """Run the generator over ``self.gendata`` and score the generated samples.

        :param iter:    label written as the first column of the log line;
                        defaults to the internal counter ``self._iter``.
        :return:        space-separated string of the non-None scorer results.
        """
        # NOTE: ``iter`` shadows the builtin but is kept for keyword callers.
        iter = self._iter if iter is None else iter
        self.generator.eval()
        with torch.no_grad():
            # collect generated images
            generated = []
            self.tt.tick("running generator")
            for i, batch in enumerate(self.gendata):
                batch = (batch, ) if not q.issequence(batch) else batch
                # as_tensor avoids the copy-construct warning that torch.tensor()
                # emits when batch_e is already a tensor
                batch = [
                    torch.as_tensor(batch_e).to(self.device) for batch_e in batch
                ]
                _gen = self.generator(*batch)
                # pick the first output BEFORE detaching: a tuple/list has no
                # .detach(), so detaching first crashes on multi-output generators
                _gen = _gen[0] if q.issequence(_gen) else _gen
                generated.append(_gen.detach().cpu())
                self.tt.live("{}/{}".format(i, len(self.gendata)))
            # reuse the largest generator batch size when re-batching for scorers
            batsize = max(map(len, generated))
            generated = torch.cat(generated, 0)
            self.tt.tock("generated data")

            gen_loaded = q.dataload(generated,
                                    batch_size=batsize,
                                    shuffle=False)
            rets = [iter]
            for scorer in self.scorers:
                ret = scorer(gen_loaded)
                if ret is not None:
                    rets.append(ret)
            if self.logger is not None:
                self.logger.liner_write("validator-{}.txt".format(self.name),
                                        " ".join(map(str, rets)))
            self._iter += 1
        # drop the iteration label; only scorer results are returned
        return " ".join(map(str, rets[1:]))
Пример #2
0
 def forward(self, ldata, rdata):
     """Score each left/right example pair with ``self.sim``.

     :param ldata:   input(s) for ``self.lmodel``; a lone non-sequence input is
                     wrapped in a 1-tuple before being unpacked as arguments.
     :param rdata:   input(s) for ``self.rmodel``; wrapped the same way.
     :return:        the result of ``self.sim`` on the two encodings
                     (presumably one similarity per example, (batsize,) — confirm).
     """
     if not q.issequence(ldata):
         ldata = (ldata, )
     if not q.issequence(rdata):
         rdata = (rdata, )
     left_vectors = self.lmodel(*ldata)
     right_vectors = self.rmodel(*rdata)
     return self.sim(left_vectors, right_vectors)
Пример #3
0
 def __init__(self, scoremodel, eids, ldata, rdata, eid2rid_gold,
              eid2rid_neg):
     """Store the scoring model together with the data it is evaluated on.

     :param scoremodel:    model used to score candidates.
     :param eids:          entity ids, stored as given.
     :param ldata:         left-side data (already shuffled); wrapped in a
                           1-tuple when it is not a sequence.
     :param rdata:         right-side data indexed by eid space; wrapped likewise.
     :param eid2rid_gold:  gold mapping, indexed by eid space.
     :param eid2rid_neg:   negative mapping, indexed by eid space.
     """
     def _as_tuple(data):
         # a lone input is wrapped so it can always be unpacked with *
         return data if q.issequence(data) else (data, )

     self.scoremodel = scoremodel
     self.eids = eids
     self.ldata = _as_tuple(ldata)
     self.rdata = _as_tuple(rdata)
     self.eid2rid_neg = eid2rid_neg
     self.eid2rid_gold = eid2rid_gold
Пример #4
0
    def forward(self, ldata, posrdata, negrdata):
        """Pairwise hinge loss between a positive and a negative right side.

        Per example: ``max(0, margin - (sim(l, pos) - sim(l, neg)))``.

        :param ldata:     input(s) for ``self.lmodel``; a non-sequence is wrapped.
        :param posrdata:  positive input(s) for ``self.rmodel``; wrapped likewise.
        :param negrdata:  negative input(s) for ``self.rmodel``; wrapped likewise.
        :return:          per-example losses, same shape as ``self.sim``'s output.
        """
        ldata = ldata if q.issequence(ldata) else (ldata, )
        posrdata = posrdata if q.issequence(posrdata) else (posrdata, )
        negrdata = negrdata if q.issequence(negrdata) else (negrdata, )
        lvecs = self.lmodel(*ldata)
        rvecs = self.rmodel(*posrdata)
        nrvecs = self.rmodel(*negrdata)
        psim = self.sim(lvecs, rvecs)
        nsim = self.sim(lvecs, nrvecs)

        diffs = psim - nsim
        # clamp(min=0) is the elementwise max with zero; it stays on diffs'
        # device/dtype without the legacy q.var(...).cuda(...).v wrapper or .data
        losses = torch.clamp(self.margin - diffs, min=0)

        return losses
Пример #5
0
 def __init__(self, size_average=True, ignore_index=None, **kw):
     """
     :param size_average:  forwarded to the parent constructor.
     :param ignore_index:  a single index or a sequence of indices to ignore;
                           stored in ``self.ignore_indices`` (a scalar is wrapped
                           in a list, ``None`` means nothing is ignored).
     :param kw:            extra keyword arguments forwarded to the parent.
     """
     super(DiscreteLoss, self).__init__(size_average=size_average, **kw)
     if ignore_index is not None:
         if not q.issequence(ignore_index):
             ignore_index = [ignore_index]
         # assign for both the scalar and the sequence case, otherwise a
         # sequence argument would leave ignore_indices undefined
         self.ignore_indices = ignore_index
     else:
         self.ignore_indices = None
Пример #6
0
def dataset(*x):
    """Build a dataset from tensors via ``tensor_dataset``.

    Accepts the tensors either as separate positional arguments,
    ``dataset(a, b)``, or as a single sequence, ``dataset((a, b))``.
    """
    # *x always yields a tuple, so testing x itself is meaningless; only unwrap
    # when exactly one argument was given and that argument is a sequence.
    if len(x) == 1 and q.issequence(x[0]):
        x = x[0]
    return tensor_dataset(*x)
Пример #7
0
    def evalloop(self):
        """Run the model over ``self.dataloader`` without gradients.

        Each batch is wrapped into a sequence if needed, moved to
        ``self._device``, optionally passed through ``transform_batch_inp``
        before the model and ``transform_batch_out`` after it.

        :return:    ``torch.cat`` of all (possibly transformed) outputs along dim 0.
        """
        self.tt.tick("testing")
        progress = ticktock("-")
        num_batches = len(self.dataloader)
        self.model.eval()
        collected = []
        with torch.no_grad():
            for batch_idx, raw in enumerate(self.dataloader):
                if not q.issequence(raw):
                    raw = (raw, )
                inputs = [tensor.to(self._device) for tensor in raw]
                if self.transform_batch_inp is not None:
                    inputs = self.transform_batch_inp(*inputs)

                batch_reset(self.model)
                outputs = self.model(*inputs)

                if self.transform_batch_out is not None:
                    outputs = self.transform_batch_out(outputs)

                progress.live("eval - [{}/{}]".format(batch_idx + 1, num_batches))
                collected.append(outputs)
        progress.stoplive()
        progress.tock("eval done")
        self.tt.tock("tested")
        return torch.cat(collected, 0)
Пример #8
0
    def testloop(self, epoch=None):
        """Evaluate ``self.losses`` over ``self.dataloader`` and return a summary.

        :param epoch:   ``None`` for a standalone test run (also ticks/tocks
                        ``self.tt``), or a ``(current_epoch, max_epochs)`` pair
                        used only in the progress message.
        :return:        summary string ``"<name>: <self.losses.pp()>"``.
        """
        if epoch is None:
            self.tt.tick("testing")
        tt = ticktock("-")
        self.model.eval()
        self.do_callbacks(self.START_TEST)
        self.losses.push_and_reset()
        totalbats = len(self.dataloader)
        for i, _batch in enumerate(self.dataloader):
            self.do_callbacks(self.START_BATCH)

            _batch = (_batch, ) if not q.issequence(_batch) else _batch
            _batch = [batch_e.to(self._device) for batch_e in _batch]
            if self.transform_batch_inp is not None:
                batch = self.transform_batch_inp(*_batch)
            else:
                batch = _batch

            # with no_gold every element is a model input and there is no gold;
            # otherwise the last element is the gold
            if self.no_gold:
                batch_in = batch
                gold = None
            else:
                batch_in = batch[:-1]
                gold = batch[-1]

            batch_reset(self.model)
            modelouts = self.model(*batch_in)

            modelout2loss = modelouts
            if self.transform_batch_out is not None:
                modelout2loss = self.transform_batch_out(modelouts)
            # gold comes only from the branch above; re-reading batch[-1] here
            # would feed the last model input as gold when no_gold is set
            if self.transform_batch_gold is not None:
                gold = self.transform_batch_gold(gold)

            self.losses(modelout2loss, gold)

            epochmsg = ""
            if epoch is not None:
                curepoch, maxepoch = epoch
                epochmsg = "Epoch {}/{} -".format(curepoch, maxepoch)

            tt.live("{} - {}[{}/{}]: {}".format(self._name, epochmsg, i + 1,
                                                totalbats, self.losses.pp()))
            self.do_callbacks(self.END_BATCH)
        tt.stoplive()
        ttmsg = "{}: {}".format(self._name, self.losses.pp())
        self.do_callbacks(self.END_TEST)
        if epoch is None:
            tt.tock(ttmsg)
            self.tt.tock("tested")
        return ttmsg
Пример #9
0
 def __init__(self, size_average=True, ignore_index=0):
     """
     :param size_average:  stored as ``self.size_average``.
     :param ignore_index:  index or sequence of indices to ignore (default 0);
                           a scalar is wrapped in a list and stored in
                           ``self.ignore_index``; ``None`` ignores nothing.
     """
     super(OldOldSeqAccuracy, self).__init__()
     self.size_average = size_average
     if ignore_index is None:
         self.ignore_index = None
     elif q.issequence(ignore_index):
         self.ignore_index = ignore_index
     else:
         self.ignore_index = [ignore_index]
     self.EPS = 1e-6
Пример #10
0
 def __call__(self, gendata):
     """Concatenate all batches of ``gendata`` and save them with ``np.savez``.

     Every batch is treated as a tuple of tensors (a lone tensor becomes a
     1-tuple). The k-th elements of all batches are moved to CPU, concatenated
     along dim 0, converted to numpy and stored under key ``str(k)``. The file
     ``os.path.join(self.logger.p, self.p)`` is written only when a logger is set.
     """
     batches = [batch if q.issequence(batch) else (batch, )
                for batch in gendata]
     cpu_batches = [[part.cpu() for part in batch] for batch in batches]
     columns = [torch.cat(parts, 0).numpy() for parts in zip(*cpu_batches)]
     tosave = {str(k): column for k, column in enumerate(columns)}
     if self.logger is not None:
         np.savez(os.path.join(self.logger.p, self.p), **tosave)
Пример #11
0
 def __call__(self, pred, gold, **kw):
     """Compute the wrapped loss, record it in the aggregate and return it.

     The example count defaults to the first dimension of ``pred`` (of
     ``pred[0]`` when ``pred`` is a sequence). If the loss returns a 2-tuple
     ``(loss, numex)``, that ``numex`` is used instead.

     :return:    the loss (the first element when a 2-tuple was returned).
     """
     result = self.loss(pred, gold, **kw)
     first = pred[0] if q.issequence(pred) else pred
     numex = first.size(0)
     if isinstance(result, tuple) and len(result) == 2:
         result, numex = result
     value = result.item() if isinstance(result, torch.Tensor) else result
     self.update_agg(value, numex)
     return result
Пример #12
0
 def __init__(self,
              weight=None,
              size_average=True,
              time_average=True,
              ignore_index=0):
     """
     :param weight:        forwarded to the parent loss.
     :param size_average:  forwarded to the parent loss.
     :param time_average:  stored as ``self.time_average``.
     :param ignore_index:  index or sequence of indices to ignore (default 0);
                           a scalar is wrapped in a list before forwarding,
                           ``None`` is forwarded unchanged.
     """
     if ignore_index is not None and not q.issequence(ignore_index):
         ignore_index = [ignore_index]
     super(SeqNLLLoss, self).__init__(weight=weight,
                                      size_average=size_average,
                                      ignore_index=ignore_index)
     self.time_average = time_average
Пример #13
0
    def __call__(self, data):
        """Compute the inception score of all samples yielded by ``data``.

        :param data:    dataloader whose batches are a single array/tensor or a
                        1-element sequence holding it.
        :return:        the 2-tuple produced by ``self.get_inception_score``.
        """
        # 1. get a numpy array from the dataloader
        chunks = []
        for batch in data:
            batch = (batch, ) if not q.issequence(batch) else batch
            assert (len(batch) == 1)
            chunks.append(batch[0])
        x = np.concatenate(chunks, axis=0)
        # 2. use the TF-based inception code to get the scores
        # (local name avoids shadowing the builtin ``vars``)
        means, variances = self.get_inception_score(x)
        return means, variances
Пример #14
0
    def __init__(self,
                 disc_trainer,
                 gen_trainer,
                 validators=None,
                 lr_decay=False):
        """
        Create a GAN trainer from a discriminator trainer and a generator trainer.
        Both trainers already hold their model, optimizer and losses and
        implement updating and batching.

        :param validators:  a single validator or a sequence of validators
                            (e.g. with different validation intervals); a single
                            one is wrapped in a 1-tuple, ``None`` is kept.
        :param lr_decay:    stored as ``self.lr_decay``.
        """
        super(GANTrainer, self).__init__()
        self.disc_trainer = disc_trainer
        self.gen_trainer = gen_trainer
        if validators is not None and not q.issequence(validators):
            validators = (validators, )
        self.validators = validators
        self.stop_training = False
        self.lr_decay = lr_decay
Пример #15
0
 def get_inception_outs(self, data):  # dataloader
     """Run every batch of ``data`` through ``self.inception``.

     :param data:    dataloader; each batch is a single input or a sequence of
                     inputs for ``self.inception``.
     :return:        ``(probs, activations)`` as numpy arrays concatenated over
                     batches along dim 0, where ``probs`` is the softmax of the
                     network's first output.
     """
     tt = q.ticktock("inception")
     tt.tick("running data through network")
     probses = []
     activationses = []
     # inference only: every output is detached, so no autograd graph is needed
     with torch.no_grad():
         for i, batch in enumerate(data):
             batch = (batch, ) if not q.issequence(batch) else batch
             # as_tensor avoids the copy-construct warning for tensor inputs
             batch = [
                 torch.as_tensor(batch_e).to(self.device) for batch_e in batch
             ]
             probs, activations = self.inception(*batch)
             # explicit dim=1: the axis the deprecated implicit choice picks for
             # 2D input (presumably (batch, classes) logits — confirm)
             probs = torch.nn.functional.softmax(probs, dim=1)
             probses.append(probs.detach())
             activationses.append(activations.detach())
             tt.live("{}/{}".format(i, len(data)))
     tt.stoplive()
     tt.tock("done")
     probses = torch.cat(probses, 0)
     activationses = torch.cat(activationses, 0)
     return probses.cpu().numpy(), activationses.cpu().numpy()
Пример #16
0
    def do_batch(self, _batch, i=-1):
        """
        Perform one optimization step on the provided batch with the configured
        model, losses and optimizer.

        :param _batch:  one batch from the dataloader; a non-sequence batch is
                        wrapped in a 1-tuple and every element is moved to
                        ``self._device``.
        :param i:       batch index, used only in the progress message
                        (shown as ``i + 1``).
        :return:        progress message with epoch, batch position,
                        ``self.losses.pp()`` and per-penalty-type totals.
        """
        self.do_callbacks(self.START_BATCH)
        self.optim.zero_grad()
        self.model.train()
        # params = q.params_of(self.model)

        _batch = (_batch, ) if not q.issequence(_batch) else _batch
        _batch = [batch_e.to(self._device) for batch_e in _batch]
        if self.transform_batch_inp is not None:
            batch = self.transform_batch_inp(*_batch)
        else:
            batch = _batch

        # with no_gold every element is a model input; otherwise the last
        # element is the gold and the rest are inputs
        if self.no_gold:
            batch_in = batch
            gold = None
        else:
            batch_in = batch[:-1]
            gold = batch[-1]

        batch_reset(self.model)
        modelouts = self.model(*batch_in)

        modelout2loss = modelouts
        if self.transform_batch_out is not None:
            modelout2loss = self.transform_batch_out(modelouts)

        if self.transform_batch_gold is not None:
            gold = self.transform_batch_gold(gold)
        trainlosses = self.losses(modelout2loss, gold)

        # TODO: put in penalty mechanism
        # NOTE(review): penalties are gathered below; this TODO may be stale.

        # the first entry returned by self.losses is the one optimized
        # (presumably the primary training loss — confirm)
        cost = trainlosses[0]

        # on NaN, print and call q.embed() (presumably an interactive debug
        # shell — confirm); training continues afterwards
        if torch.isnan(cost).any():
            print("Cost is NaN!")
            q.embed()

        # sum penalties from all penalty modules into the cost, and keep a
        # detached per-module-type total for the progress message
        penalties = 0.
        penalty_values = {}
        for penalty_module in q.gather_penalties(self.model):
            pen = penalty_module.get_penalty()
            penalties = penalties + pen
            if type(penalty_module) not in penalty_values:
                penalty_values[type(penalty_module)] = 0.
            penalty_values[type(penalty_module)] += pen.detach().item()

        cost = cost + penalties

        cost.backward()

        # callbacks bracket the optimizer step so they can inspect/modify grads
        self.do_callbacks(self.BEFORE_OPTIM_STEP)
        self.optim.step()
        self.do_callbacks(self.AFTER_OPTIM_STEP)

        ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
            self.current_epoch + 1,
            self.max_epochs,
            i + 1,
            len(self.dataloader),
            self.losses.pp(),
        )
        # append "+<name>=<value>" per penalty type, named by the class's
        # __pp_name__ attribute
        if len(penalty_values) > 0:
            ttmsg += " " + " ".join([
                "+{}={:.4f}".format(k.__pp_name__, v)
                for k, v in penalty_values.items()
            ])
        self.do_callbacks(self.END_BATCH)
        return ttmsg
Пример #17
0
    def forward(self, x, whereat, ctx=None, ctx_mask=None):
        """
        One decoding step: encode the partial sequence, run the core on the
        embeddings (optionally fed with the previous attention output), attend
        over ``ctx`` if an attention module is set, and produce the output.

        :param x:        (batsize, out_seqlen) integer ids of the words in the partial tree so far
        :param whereat:  (batsize,) integers, after which position index in x to insert predicted token
        :param ctx:      (batsize, in_seqlen, dim) the context to be used with attention
        :param ctx_mask: mask for ``ctx``, passed through to ``self.att``
        :return:         the output (``self.out`` applied unless
                         ``return_outvecs``), plus alphas / scores /
                         (embs, core_out, summaries) when the corresponding
                         ``return_*`` flags are set; a bare value if only one.
        """
        if isinstance(self.out, q.AutoMaskedOut):
            self.out.update(x, whereat)

        # the embedder may return (embeddings, mask)
        mask = None
        embs = self.emb(x)
        if q.issequence(embs) and len(embs) == 2:
            embs, mask = embs

        # NOTE(review): x[:, :, 0] needs a 3D x, but the docstring documents x
        # as 2D (batsize, out_seqlen); with 2D x this indexing fails — confirm.
        if mask is None:
            mask = torch.ones_like(x[:, :, 0])
        y = self.encoder(x, mask=mask)
        # y's last dim is split in halves: first half read at whereat, second
        # half at whereat + 1 (presumably forward/reverse directions — confirm)
        whereat_sel = whereat.view(whereat.size(0), 1,
                                   1).repeat(1, 1,
                                             y.size(2) // 2)
        z_fwd = y[:, :, :y.size(2) // 2].gather(1, whereat_sel).squeeze(
            1)  # (batsize, dim)  # not completely correct but fine
        # NOTE(review): whereat + 1 may index past the last position — confirm
        z_rev = y[:, :, y.size(2) // 2:].gather(1, whereat_sel + 1).squeeze(1)
        # NOTE(review): z is not used anywhere below in this method.
        z = torch.cat([z_fwd, z_rev], 1)

        core_inp = embs

        # with feed_att, the previous step's attention output is concatenated to
        # the core input; h_hat_0 seeds it on the first step
        if self.att is not None:
            assert (ctx is not None)
            if self.feed_att:
                if self._h_hat_tm1 is None:
                    assert (self.h_hat_0 is not None
                            )  #"h_hat_0 must be set when feed_att=True"
                    self._h_hat_tm1 = self.h_hat_0
                core_inp = torch.cat([core_inp, self._h_hat_tm1], 1)

        core_out = self.core(core_inp)

        alphas, summaries, scores = None, None, None
        out_vec = core_out
        if self.att is not None:
            alphas, summaries, scores = self.att(core_out,
                                                 ctx,
                                                 ctx_mask=ctx_mask,
                                                 values=ctx)
            out_vec = torch.cat([core_out, summaries], 1)
            out_vec = self.merge(
                out_vec) if self.merge is not None else out_vec
            # remembered for the next step's feed_att input
            self._h_hat_tm1 = out_vec

        ret = tuple()
        if not self.return_outvecs:
            out_vec = self.out(out_vec)
        ret += (out_vec, )

        if self.return_alphas:
            ret += (alphas, )
        if self.return_scores:
            ret += (scores, )
        if self.return_other:
            ret += (embs, core_out, summaries)
        return ret[0] if len(ret) == 1 else ret
Пример #18
0
 def _forward(self, model_outs, gold, **kw):
     """Select the ``self.which``-th model output; ``gold`` is not used.

     :return:    ``(selected_output, None)``. A non-sequence ``model_outs`` is
                 only accepted when ``self.which == 0``.
     """
     if not q.issequence(model_outs):
         assert (self.which == 0)
         return model_outs, None
     return model_outs[self.which], None