Example #1
def eval_loop(model, dataloader, device=torch.device("cpu")):
    tto = q.ticktock("testing")
    tto.tick("testing")
    tt = q.ticktock("-")
    totaltestbats = len(dataloader)
    model.eval()
    epoch_reset(model)
    outs = []
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = (batch, ) if not q.issequence(batch) else batch
            batch = q.recmap(
                batch, lambda x: x.to(device)
                if isinstance(x, torch.Tensor) else x)

            batch_reset(model)
            modelouts = model(*batch)

            tt.live("eval - [{}/{}]".format(i + 1, totaltestbats))
            if not q.issequence(modelouts):
                modelouts = (modelouts, )
            if len(outs) == 0:
                outs = [[] for _ in modelouts]
            for out_e, mout_e in zip(outs, modelouts):
                out_e.append(mout_e)
    ttmsg = "eval done"
    tt.stoplive()
    tt.tock(ttmsg)
    tto.tock("tested")
    ret = [torch.cat(out_e, 0) for out_e in outs]
    return ret
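All of these examples lean on q.issequence to normalize values that may be a single tensor or a tuple of tensors. The helper itself is not shown on this page; a minimal sketch of what it presumably does, treating only lists and tuples as sequences:

def issequence(x):
    # lists and tuples count as sequences; tensors, dicts and scalars do not
    return isinstance(x, (list, tuple))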
Example #2
    def forward(self, x, mask=None):
        fwd_ret = self.layer_fwd(x, mask=mask)
        rev_ret = self.layer_rev(x, mask=mask)

        if self.mode == "cat":
            merge_fn = lambda a, b: torch.cat([a, b], -1)
        else:
            merge_fn = lambda a, b: a + b

        if not q.issequence(fwd_ret):
            fwd_ret = [fwd_ret]
        if not q.issequence(rev_ret):
            rev_ret = [rev_ret]
        ret = tuple()
        if self._return_final:
            ret += (merge_fn(fwd_ret[0], rev_ret[0]), )
            fwd_ret = fwd_ret[1:]
            rev_ret = rev_ret[1:]
        if self._return_all:
            # if the final state was consumed above, index 0 now holds the full sequence
            ret += (merge_fn(fwd_ret[0], rev_ret[0]), )
        if self._return_mask:
            ret += (mask, )

        if len(ret) == 1:
            return ret[0]
        elif len(ret) == 0:
            print("no output specified")
            return
        else:
            return ret
Example #3
 def __init__(self, fn, register_params=None, register_modules=None):
     super(Lambda, self).__init__()
     self.fn = fn
     # optionally registers passed modules and params
     if register_modules is not None:
         if not q.issequence(register_modules):
             register_modules = [register_modules]
         self.extra_modules = q.ModuleList(register_modules)
     if register_params is not None:
         if not q.issequence(register_params):
             register_params = [register_params]
         self.extra_params = nn.ParameterList(register_params)
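Only __init__ is shown above; presumably the class also delegates calls to the stored function. A minimal sketch of the missing forward, assumed rather than taken from this page:

 def forward(self, *args, **kw):
     # delegate to the wrapped function; registering params/modules above only
     # makes them visible to .parameters() and .to(device)
     return self.fn(*args, **kw)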
Example #4
    def forward(self, x, mask=None):
        x = self.dropout(x)

        if mask is not None:
            _x = torch.nn.utils.rnn.pack_padded_sequence(x,
                                                         mask.sum(-1),
                                                         batch_first=True,
                                                         enforce_sorted=False)
        else:
            _x = x

        _outputs, hidden = self.rnn(_x)

        if mask is not None:
            y, _ = torch.nn.utils.rnn.pad_packed_sequence(_outputs,
                                                          batch_first=True)
        else:
            y = _outputs

        hidden = (hidden, ) if not q.issequence(hidden) else hidden
        hiddens = []
        for _hidden in hidden:
            i = 0
            _hiddens = tuple()
            while i < _hidden.size(0):
                if self.bidir:
                    # merge forward and backward directions of the same layer
                    _h = torch.cat([_hidden[i], _hidden[i + 1]], -1)
                    i += 2
                else:
                    _h = _hidden[i]
                    i += 1
                _hiddens = _hiddens + (_h, )
            hiddens.append(_hiddens)
        hiddens = tuple(zip(*hiddens))
        return y, hiddens
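The pack/pad round trip above lets the RNN skip padded positions. The same pattern in isolation, with toy sizes (all names here are illustrative):

import torch

x = torch.randn(2, 5, 8)                          # (batsize, seqlen, dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
packed = torch.nn.utils.rnn.pack_padded_sequence(
    x, mask.sum(-1), batch_first=True, enforce_sorted=False)
rnn = torch.nn.GRU(8, 6, batch_first=True)
packed_out, h = rnn(packed)
y, lens = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
# y: (2, 5, 6), zero-padded past each sequence's true length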
Example #5
 def forward(self, *x, **kw):
     y_l = x
     args, kwargs = None, None
     argmapped = False
     for layer in self.layers:
         if argmapped:
             rargs = args
             rkw = {}
             rkw.update(kw)
             rkw.update(kwargs)
         else:
             rargs = y_l
             rkw = kw
         if isinstance(layer, q.argmap):
             args, kwargs = layer(rargs, rkw, self._saved_slots)
             argmapped = True
         elif isinstance(layer, argsave):
             layer(rargs, rkw, self._saved_slots)  # presumably stores into the saved slots; return value unused
         else:
             y_l = layer(*rargs, **rkw)
             argmapped = False
         if not q.issequence(y_l) and not argmapped:
             y_l = tuple([y_l])
     if argmapped:
         ret = args
     else:
         ret = y_l
     if len(ret) == 1:
         ret = ret[0]
     return ret
Example #6
    def get_entity_property(self, entities, property, language=None):
        if not q.issequence(entities):
            entities = [entities]
        entities = [fbfy(entity) for entity in entities]
        propertychain = [fbfy(p) for p in property.strip().split()]
        propchain = ""
        prevvar = "?s"
        varcount = 0
        for prop in propertychain:
            newvar = "?var{}".format(varcount)
            varcount += 1
            propchain += "{} {} {} .\n".format(prevvar, prop, newvar)
            prevvar = newvar
        propchain = propchain.replace(prevvar, "?o")

        query = """SELECT ?s ?o WHERE {{
                        {}
                        VALUES ?s {{ {} }}
                        {}
                    }}""".format(
            propchain,
            " ".join(entities),
            "FILTER (lang(?o) = '{}')".format(language) if language is not None else "")
        ret = {}
        res = self._exec_query(query)
        results = res["results"]["bindings"]
        for result in results:
            s = unfbfy(result["s"]["value"])
            if s not in ret:
                ret[s] = set()
            val = result["o"]["value"]
            if language is None:
                val = unfbfy(val)
            ret[s].add(val)
        return ret
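For a two-hop property string such as "p1 p2" and a single entity, the loop above chains fresh variables and renames the last one to ?o, so the generated query is shaped roughly like the following (the fb: prefixes produced by fbfy are an assumption):

SELECT ?s ?o WHERE {
    ?s fb:p1 ?var0 .
    ?var0 fb:p2 ?o .
    VALUES ?s { fb:m.0abc12 }
    FILTER (lang(?o) = 'en')
}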
Example #7
    def forward(self, x: State):
        if "mstate" not in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            xlmrstates = self.xlmr.extract_features(inptensor)
            inpenc = xlmrstates
            final_enc = xlmrstates[:, 0, :]
            for i in range(len(self.enc_to_dec)):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc)
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_enc
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        return outs[0], x
Example #8
 def __init__(self, size_average=True, ignore_index=None, **kw):
     super(DiscreteLoss, self).__init__(**kw)
     if ignore_index is not None:
         if not q.issequence(ignore_index):
             ignore_index = [ignore_index]
         self.ignore_indices = ignore_index  # holds a list of indices, whether one or several were given
     else:
         self.ignore_indices = None
     self.size_average = size_average
Example #9
    def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
        if ctx is None:
            ctx, ctx_mask = self._saved_ctx, self._saved_ctx_mask
        assert (ctx is not None)

        # if isinstance(self.out, q.rnn.AutoMaskedOut):
        #     self.out.update(x_t)
        self.out.update(x_t)

        embs = self.emb(x_t)        # embed input tokens
        if q.issequence(embs) and len(embs) == 2:   # unpack if necessary
            embs, mask = embs

        if self.feed_att:
            if self._outvec_tm1 is None:
                assert self.outvec_t0 is not None, "h_hat_0 must be set when feed_att=True"
                self._outvec_tm1 = self.outvec_t0
            core_inp = torch.cat([embs, self._outvec_tm1], 1)     # append previous attention summary
        else:
            core_inp = embs

        core_out = self.core(core_inp)  # feed through rnn

        # do normal attention over input
        alphas, summaries, scores = self.att(core_out, ctx, ctx_mask=ctx_mask, values=ctx)  # do attention

        # do attention over decoded sequence
        if self.selfatt is not None and self.prev_coreouts is not None:
            selfalphas, selfsummaries, selfscores = self.selfatt(core_out, self.prev_coreouts)  # do self-attention
        else:
            selfalphas, selfsummaries, selfscores = None, None, None
        # TODO ??? use self-attention summaries for output generation too?

        out_vec = self.merge(core_out, summaries, core_inp)
        out_vec = self.dropout(out_vec)
        self._outvec_tm1 = out_vec      # store outvec (this is how Luong, 2015 does it)

        # save coreouts
        if self.selfatt is not None and self.prev_coreouts is not None:
            self.prev_coreouts = torch.cat([self.prev_coreouts, core_out.unsqueeze(1)], 1)
        else:
            self.prev_coreouts = core_out.unsqueeze(1)      # introduce a sequence dimension

        ret = tuple()
        if self.out is None:
            ret += (out_vec,)
        else:
            _out_vec = self.out(out_vec, scores=scores, selfscores=selfscores)
            ret += (_out_vec,)

        # other returns
        if self.return_alphas:
            ret += (alphas,)
        if self.return_scores:
            ret += (scores,)
        if self.return_other:
            ret += (embs, core_out, summaries)
        return ret[0] if len(ret) == 1 else ret
Example #10
    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_enc = self.inp_enc(inpembs, mask)
            final_enc = final_enc.view(final_enc.size(0), -1).contiguous()
            final_enc = self.enc_to_dec(final_enc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy:
            outs = self.out_lin(enc)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0),
                                                  0,
                                                  alphas.size(1),
                                                  device=alphas.device)
            x.stored_attentions = torch.cat(
                [x.stored_attentions,
                 alphas.detach()[:, None, :]], 1)

        return outs[0], x
Example #11
 def forward(
     self,
     x,
     mask=None,
     init_states=None,
     reverse=False
 ):  # (batsize, seqlen, indim), (batsize, seqlen), [(batsize, hdim)]
     batsize = x.size(0)
     if init_states is not None:
         if not q.issequence(init_states):
             init_states = (init_states, )
         self.cell.set_init_states(*init_states)
     self.cell.reset_state()
      mask = mask if mask is not None else getattr(x, "mask", None)
     y_list = []
     y_tm1 = None
     y_t = None
     i = x.size(1)
     while i > 0:
         t = i - 1 if reverse else x.size(1) - i
         mask_t = mask[:, t].unsqueeze(1) if mask is not None else None
         x_t = x[:, t]
         cellout = self.cell(x_t, mask_t=mask_t, t=t)
         y_t = cellout
         # mask
         # if mask_t is not None:  # moved to cells (recBN is affected here)
         #     if y_tm1 is None:
         #         y_tm1 = q.var(torch.zeros(y_t.size())).cuda(crit=y_t).v
         #         if x.is_cuda: y_tm1 = y_tm1.cuda()
         #     y_t = y_t * mask_t + y_tm1 * (1 - mask_t)
         #     y_tm1 = y_t
         if self._return_all:
             y_list.append(y_t)
         i -= 1
     ret = tuple()
     if self._return_final:
         ret += (y_t, )
     if self._return_all:
         if reverse: y_list.reverse()
         y = torch.stack(y_list, 1)
         ret += (y, )
     if self._return_mask:
         ret += (mask, )
     if len(ret) == 1:
         return ret[0]
     elif len(ret) == 0:
         print("no output specified")
         return
     else:
         return ret
Example #12
 def get_ignore_mask(gold, ignore_indices):
     if ignore_indices is not None and not q.issequence(ignore_indices):
         ignore_indices = [ignore_indices]
      mask = None  # bool mask with the same shape as gold
     if ignore_indices is not None:
         for ignore in ignore_indices:
             mask_i = (gold != ignore)  # zero for ignored ones
             if mask is None:
                 mask = mask_i
             else:
                 mask = mask & mask_i
      if mask is None:
          mask = torch.ones_like(gold).bool()  # nothing ignored: keep everything
     return mask
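A toy check of the resulting mask, assuming index 0 marks padding and the function is reachable as a plain function:

import torch

gold = torch.tensor([[2, 5, 0, 0],
                     [3, 0, 0, 0]])
mask = get_ignore_mask(gold, 0)
# -> [[ True,  True, False, False],
#     [ True, False, False, False]]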
Example #13
    def forward(self, model_outs, gold, **kw):
        if q.issequence(model_outs):
            x = model_outs[self.which]
        else:
            assert (self.which == 0)
            x = model_outs

        if self.reduction in ["elementwise_mean", "mean"]:
            ret = x.mean()
        elif self.reduction == "sum":
            ret = x.sum()
        else:
            ret = x
        return ret
Example #14
    def __call__(self, pred, gold, _numex=None, **kw):
        l = self.loss(pred, gold, **kw)

        if _numex is None:
            _numex = pred.size(0) if not q.issequence(pred) else pred[0].size(0)
        if isinstance(l, tuple) and len(l) == 2:  # loss returns numex too
            _numex = l[1]
            l = l[0]
        if isinstance(l, torch.Tensor):
            lp = l.item()
        else:
            lp = l
        self.epoch_agg_values.append(lp)
        self.epoch_agg_sizes.append(_numex)
        return l
Example #15
 def forward(self, *x):  # TODO: multiple inputs and outputs
     x = [xe.contiguous() for xe in x]
     x0 = x[0]
     batsize, seqlen = x0.size(0), x0.size(1)
     i = [xe.view(batsize * seqlen, *xe.size()[2:]) for xe in x]
     y = self.block(*i)
     if not q.issequence(y):
         y = (y, )
     yo = []
     for ye in y:
         ye = ye.view(batsize, seqlen, *ye.size()[1:])
         yo.append(ye)
     if len(yo) == 1:
         return yo[0]
     else:
         return tuple(yo)
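The wrapper folds the time axis into the batch axis so self.block only sees per-timestep inputs, then unfolds the result. The same trick in isolation; nn.Linear already broadcasts over leading dimensions, which makes the equivalence easy to verify:

import torch
import torch.nn as nn

x = torch.randn(4, 7, 8)                 # (batsize, seqlen, dim)
block = nn.Linear(8, 3)
y = block(x.contiguous().view(4 * 7, 8)).view(4, 7, 3)
assert torch.allclose(y, block(x))       # Linear broadcasts, so both paths agree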
Example #16
    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            init_states = []
            for i in range(len(final_encs)):
                init_states.append(self.enc_to_dec[i](final_encs[i][0]))
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state
        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy:
            outs = self.out_lin(enc)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)
        return outs[0], x
Example #17
    def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
        assert (ctx is not None)

        embs = self.emb(x_t)
        if q.issequence(embs) and len(embs) == 2:
            embs, mask = embs

        if self.feed_att:
            if self._outvec_tm1 is None:
                assert self.outvec_t0 is not None, "h_hat_0 must be set when feed_att=True"
                self._outvec_tm1 = self.outvec_t0
            core_inp = torch.cat([embs, self._outvec_tm1], 1)
        else:
            core_inp = embs

        prev_pushpop = self.get_pushpop_from(x_t)  # THIS LINE IS ADDED

        core_out = self.core(core_inp)

        alphas, summaries, scores = self.att(
            core_out,
            ctx,
            ctx_mask=ctx_mask,
            values=ctx,
            prev_pushpop=prev_pushpop)  # THIS LINE IS CHANGED
        out_vec = self.merge(core_out, summaries, core_inp)
        out_vec = self.dropout(out_vec)
        self._outvec_tm1 = out_vec  # store outvec

        ret = tuple()
        if self.out is None:
            ret += (out_vec, )
        else:
            _out_vec = self.out(out_vec)
            ret += (_out_vec, )

        if self.return_alphas:
            ret += (alphas, )
        if self.return_scores:
            ret += (scores, )
        if self.return_other:
            ret += (embs, core_out, summaries)
        return ret[0] if len(ret) == 1 else ret
Example #18
    def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
        assert (ctx is not None)

        if isinstance(self.out, q.rnn.AutoMaskedOut):
            self.out.update(x_t)

        embs = self.emb(x_t)        # embed input tokens
        if q.issequence(embs) and len(embs) == 2:   # unpack if necessary
            embs, mask = embs

        if self.feed_att:
            if self._outvec_tm1 is None:
                assert self.outvec_t0 is not None, "h_hat_0 must be set when feed_att=True"
                self._outvec_tm1 = self.outvec_t0
            core_inp = torch.cat([embs, self._outvec_tm1], 1)     # append previous attention summary
        else:
            core_inp = embs

        prev_pushpop = self.get_pushpop_from(x_t)           # THIS LINE IS ADDED

        core_out = self.core(core_inp, prev_pushpop=prev_pushpop)  # feed through rnn   # THIS LINE IS CHANGED

        alphas, summaries, scores = self.att(core_out, ctx, ctx_mask=ctx_mask, values=ctx)  # do attention
        out_vec = self.merge(core_out, summaries, core_inp)
        out_vec = self.dropout(out_vec)
        self._outvec_tm1 = out_vec      # store outvec (this is how Luong, 2015 does it)

        ret = tuple()
        if self.out is None:
            ret += (out_vec,)
        else:
            _out_vec = self.out(out_vec)
            ret += (_out_vec,)

        # other returns
        if self.return_alphas:
            ret += (alphas,)
        if self.return_scores:
            ret += (scores,)
        if self.return_other:
            ret += (embs, core_out, summaries)
        return ret[0] if len(ret) == 1 else ret
Example #19
    def forward(self, x: BasicStateBatch):
        if "ctx" not in x.batched_states:
            # encode input
            inptensor = x.batched_states["inp_tensor"]
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_enc = self.inp_enc(inpembs, mask)
            final_enc = final_enc.view(final_enc.size(0), -1).contiguous()
            final_enc = self.enc_to_dec(final_enc)
            x.batched_states["ctx"] = inpenc
            x.batched_states["ctx_mask"] = mask

        ctx = x.batched_states["ctx"]
        ctx_mask = x.batched_states["ctx_mask"]

        emb = self.out_emb(x.batched_states["prev_token"])

        if "rnn" not in x.batched_states:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            x.batched_states["rnn"] = init_rnn_state

        # DONE: concat previous attention summary to emb
        if "prev_summ" not in x.batched_states:
            x.batched_states["prev_summ"] = torch.zeros_like(ctx[:, 0])
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, x.batched_states["prev_summ"]], 1)
        enc = self.out_rnn(_emb, x.batched_states["rnn"])

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        x.batched_states["prev_summ"] = summ
        enc = torch.cat([enc, summ], -1)

        outs = self.out_lin(enc, x, scores)
        outs = (outs, ) if not q.issequence(outs) else outs
        return outs[0], x
Example #20
    def forward(self, x_t, ctx=None, ctx_mask=None, **kw):
        if ctx is None:
            ctx, ctx_mask = self._saved_ctx, self._saved_ctx_mask
        assert (ctx is not None)

        if self.out is not None and hasattr(self.out, "update"):
            self.out.update(x_t)        # update output layer with current input

        embs = self.emb(x_t)        # embed input tokens
        if q.issequence(embs) and len(embs) == 2:   # unpack if necessary
            embs, mask = embs

        if self.feed_att:
            if self._outvec_tm1 is None:
                assert self.outvec_t0 is not None, "h_hat_0 must be set when feed_att=True"
                self._outvec_tm1 = self.outvec_t0
            core_inp = torch.cat([embs, self._outvec_tm1], 1)     # append previous attention summary
        else:
            core_inp = embs

        core_out = self.core(core_inp)  # feed through rnn

        alphas, summaries, scores = self.att(core_out, ctx, ctx_mask=ctx_mask, values=ctx)  # do attention
        out_vec = self.merge(core_out, summaries, core_inp)
        out_vec = self.dropout(out_vec)
        self._outvec_tm1 = out_vec      # store outvec (this is how Luong, 2015 does it)

        if self.out is None:
            ret_normal = out_vec
        else:
            if isinstance(self.out, PointerGeneratorOut):
                _out_vec = self.out(out_vec, scores=scores)
            else:
                _out_vec = self.out(out_vec)
            ret_normal = _out_vec

        # gather the requested outputs by name from locals(); self.returns is a
        # list of lists of variable names, flattened here with sum(..., [])
        l = locals()
        ret = tuple([l[k] for k in sum(self.returns, [])])
        return ret[0] if len(ret) == 1 else ret
Example #21
def train_batch(batch=None,
                model=None,
                optim=None,
                losses=None,
                device=torch.device("cpu"),
                batch_number=-1,
                max_batches=0,
                current_epoch=0,
                max_epochs=0,
                on_start=tuple(),
                on_before_optim_step=tuple(),
                on_after_optim_step=tuple(),
                on_end=tuple()):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param batch:  batch to run on
    :param model:   torch.nn.Module of the model
    :param optim:       torch optimizer
    :param losses:      list of losswrappers
    :param device:      device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:              collection of functions to call when batch is done
    :return:    pretty-printed progress message (ttmsg) for this batch
    """
    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
    numex = batch[0].size(0)

    if q.no_gold(losses):
        batch_in = batch
        gold = None
    else:
        batch_in = batch[:-1]
        gold = batch[-1]

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, gold, _numex=numex)
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    # penalties
    penalties = 0
    for loss_obj, trainloss in zip(losses, trainlosses):
        if isinstance(loss_obj.loss, q.loss.PenaltyGetter):
            penalties += trainloss
    cost = cost + penalties

    if torch.isnan(cost).any():
        print("Cost is NaN!")
        embed()  # drop into an interactive shell for debugging (embed presumably imported from IPython)

    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
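A hypothetical driver for train_batch with a toy model. The loss wrapper is whatever class owns the __call__ shown in Example #14; its name (q.LossWrapper) is an assumption, not confirmed by this page:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = [q.LossWrapper(nn.CrossEntropyLoss())]   # wrapper name assumed
dataloader = [(torch.randn(8, 4), torch.randint(0, 2, (8,)))]
for i, batch in enumerate(dataloader):
    msg = train_batch(batch=batch, model=model, optim=optim, losses=losses,
                      batch_number=i, max_batches=len(dataloader))
    print(msg)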
Example #22
def train_batch_distill(batch=None,
                        model=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        batch_number=-1,
                        max_batches=0,
                        current_epoch=0,
                        max_epochs=0,
                        on_start=tuple(),
                        on_before_optim_step=tuple(),
                        on_after_optim_step=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Runs a single batch of SGD on provided batch and settings.
    :param batch:  batch to run on
    :param model:   torch.nn.Module of the model
    :param optim:       torch optimizer
    :param losses:      list of losswrappers
    :param device:      device
    :param batch_number:    which batch
    :param max_batches:     total number of batches
    :param current_epoch:   current epoch
    :param max_epochs:      total number of epochs
    :param on_start:        collection of functions to call when starting training batch
    :param on_before_optim_step:    collection of functions for before optimization step is taken (gradclip)
    :param on_after_optim_step:     collection of functions for after optimization step is taken
    :param on_end:              collection of functions to call when batch is done
    :param mbase:           base model where to distill from. takes inputs and produces output distributions to match by student model. if goldgetter is specified, this is not used.
    :param goldgetter:      takes the gold and produces a softgold
    :return:    pretty-printed progress message (ttmsg) for this batch
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_batch, **kwargs)

    [e() for e in on_start]
    optim.zero_grad()
    model.train()

    batch = (batch, ) if not q.issequence(batch) else batch
    batch = q.recmap(
        batch, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)

    batch_in = batch[:-1]
    gold = batch[-1]

    # run batch_in through teacher model to get teacher output distributions
    if goldgetter is not None:
        softgold = goldgetter(gold)
    elif mbase is not None:
        mbase.eval()
        q.batch_reset(mbase)
        with torch.no_grad():
            softgold = mbase(*batch_in)
    else:
        raise q.SumTingWongException(
            "goldgetter and mbase can not both be None")

    q.batch_reset(model)
    modelouts = model(*batch_in)

    trainlosses = []
    for loss_obj in losses:
        loss_val = loss_obj(modelouts, (softgold, gold))
        loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
        trainlosses.extend(loss_val)

    cost = trainlosses[0]
    cost.backward()

    [e() for e in on_before_optim_step]
    optim.step()
    [e() for e in on_after_optim_step]

    ttmsg = "train - Epoch {}/{} - [{}/{}]: {}".format(
        current_epoch + 1,
        max_epochs,
        batch_number + 1,
        max_batches,
        q.pp_epoch_losses(*losses),
    )

    [e() for e in on_end]
    return ttmsg
Example #23
    def forward(self, x: State):
        if "mstate" not in x:
            x.mstate = State()
            x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

            if self.training and q.v(self.beta) < 1:    # sample one of the orders
                golds = x._gold_tensors
                goldsmask = (golds != 0).any(-1).float()
                numgolds = goldsmask.sum(-1)
                gold_select_prob = torch.ones_like(goldsmask) * goldsmask / numgolds[:, None]
                selector = gold_select_prob.multinomial(1)[:, 0]
                gold = golds.gather(1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
                # interpolate with original gold
                original_gold = x.gold_tensor
                beta_selector = (torch.rand_like(numgolds) <= q.v(self.beta)).long()
                gold_ = original_gold * beta_selector[:, None] + gold * (1 - beta_selector[:, None])
                x.gold_tensor = gold_

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]

        _emb = emb

        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        mstate.decoding_step = mstate.decoding_step + 1

        return outs[0], x
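The order-sampling block above draws one of several acceptable gold sequences per example. The same sampling in isolation with toy tensors (the probability line simplifies to a uniform choice over non-padding rows):

import torch

golds = torch.tensor([[[2, 3], [4, 5], [0, 0]],
                      [[7, 8], [0, 0], [0, 0]]])   # (batsize, numgolds, seqlen)
goldsmask = (golds != 0).any(-1).float()           # which gold rows are real
gold_select_prob = goldsmask / goldsmask.sum(-1, keepdim=True)
selector = gold_select_prob.multinomial(1)[:, 0]
gold = golds.gather(1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
# gold: (batsize, seqlen), one sampled order per example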
Example #24
    def forward(self, x: State):
        if "mstate" not in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        # ONR stuff: !!! assumes LISP style queries with parentheses as separate tokens and only parentheses opening and closing clauses
        stack_actions = torch.zeros_like(x.prev_actions)
        stack_actions += (x.prev_actions == self.open_id).long() * +1
        stack_actions += (x.prev_actions == self.close_id).long() * -1

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt:
            _ctx = mstate.prev_summ
        else:
            _ctx = torch.zeros(_emb.size(0), 0, device=_emb.device)

        enc, new_rnnstate = self.out_rnn(_emb, _ctx, stack_actions, mstate.rnnstate)

        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy:
            outs = self.out_lin(enc)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        return outs[0], x
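The stack-action lines above encode the previous token as +1 for an opening parenthesis, -1 for a closing one, and 0 otherwise. A toy check with hypothetical token ids:

import torch

prev_actions = torch.tensor([3, 7, 5, 8, 8])   # hypothetical ids
open_id, close_id = 7, 8
stack_actions = torch.zeros_like(prev_actions)
stack_actions += (prev_actions == open_id).long() * +1
stack_actions += (prev_actions == close_id).long() * -1
# -> tensor([ 0,  1,  0, -1, -1])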
Example #25
    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):  # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        if "prevstates" not in mstate:
            _ctx = ctx
            _ctx_mask = ctx_mask
            mstate.prevstates = enc[:, None, :]
        else:
            _ctx = torch.cat([ctx, mstate.prevstates], 1)
            _ctx_mask = torch.cat([
                ctx_mask,
                torch.ones(mstate.prevstates.size(0),
                           mstate.prevstates.size(1),
                           dtype=ctx_mask.dtype,
                           device=ctx_mask.device)
            ], 1)
            mstate.prevstates = torch.cat([mstate.prevstates, enc[:, None, :]],
                                          1)

        alphas, summ, scores = self.att(enc, _ctx, _ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0),
                                                  0,
                                                  alphas.size(1),
                                                  device=alphas.device)
            atts = q.pad_tensors(
                [x.stored_attentions,
                 alphas.detach()[:, None, :]], 2, 0)
            x.stored_attentions = torch.cat(atts, 1)

        return outs[0], x
Example #26
def test_epoch(model=None,
               dataloader=None,
               losses=None,
               device=torch.device("cpu"),
               current_epoch=0,
               max_epochs=0,
               print_every_batch=False,
               on_start=tuple(),
               on_start_batch=tuple(),
               on_end_batch=tuple(),
               on_end=tuple()):
    """
    Performs a test epoch over the given dataloader and losses.
    :param model:
    :param dataloader:
    :param losses:
    :param device:
    :param current_epoch:
    :param max_epochs:
    :param on_start:
    :param on_start_batch:
    :param on_end_batch:
    :param on_end:
    :return:    pretty-printed epoch losses (ttmsg)
    """
    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)
    [e() for e in on_start]
    with torch.no_grad():
        for loss_obj in losses:
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)
        for i, _batch in enumerate(dataloader):
            [e() for e in on_start_batch]

            _batch = (_batch, ) if not q.issequence(_batch) else _batch
            _batch = q.recmap(
                _batch, lambda x: x.to(device)
                if isinstance(x, torch.Tensor) else x)
            batch = _batch
            numex = batch[0].size(0)

            if q.no_gold(losses):
                batch_in = batch
                gold = None
            else:
                batch_in = batch[:-1]
                gold = batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)

            testlosses = []
            for loss_obj in losses:
                loss_val = loss_obj(modelouts, gold, _numex=numex)
                loss_val = [loss_val] if not q.issequence(loss_val) else loss_val
                testlosses.extend(loss_val)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1, max_epochs, i + 1, len(dataloader),
                q.pp_epoch_losses(*losses))
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
            [e() for e in on_end_batch]
    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
Example #27
    def forward(self, x: State):
        if "mstate" not in x:
            x.mstate = State()
            x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        if not "outenc" in mstate:
            if self.training:
                outtensor = x.gold_tensor
                omask = outtensor != 0
                outembs = self.out_emb_vae(outtensor)
                finalenc, _ = self.out_enc(outembs, omask)
                # adding log(mask) puts -inf at padded positions so they never win the max
                finalenc, _ = (finalenc + torch.log(omask.float()[:, :, None])).max(1)        # max pool
                # reparam
                mu = self.out_mu(finalenc)
                logvar = self.out_logvar(finalenc)
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                outenc = mu + eps * std
                mstate.outenc = outenc
                kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
                kld = torch.sum(kld.clamp_min(self.minkl), -1)
                mstate.kld = kld

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]

        if self.training:
            outenc = mstate.outenc
            # outenc = outenc.gather(1, mstate.decoding_step[:, None, None].repeat(1, 1, outenc.size(2)))[:, 0]
        else:
            outenc = torch.randn(emb.size(0), self.zdim, device=emb.device)
        _emb = torch.cat([emb, outenc], 1)

        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        mstate.decoding_step = mstate.decoding_step + 1

        return outs[0], x
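Example #27 mixes a VAE-style latent into the decoder input. The reparameterization and KL term used above, as a self-contained sketch with toy sizes:

import torch

mu = torch.randn(4, 8)                 # stands in for self.out_mu(finalenc)
logvar = torch.randn(4, 8)             # stands in for self.out_logvar(finalenc)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
outenc = mu + eps * std                # differentiable sample from N(mu, std**2)
kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
kld = kld.sum(-1)                      # per-example KL to the standard normal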