Exemplo n.º 1
0
    def collate_fn(data: Iterable):
        """Collate individual ranking states into one batched state.

        Every state is deep-copied with ``detach=True`` first, so the caller's
        objects are never modified. The per-example tensors listed below are
        then padded with 0 via ``q.pad_tensors`` along the given dimension,
        written back onto the copies, and the copies are combined with the
        first copy's ``merge`` method.

        Args:
            data: iterable of state objects providing ``make_copy`` and
                ``merge`` and carrying the tensor attributes named in
                ``padded_fields``.

        Returns:
            The result of ``states[0].merge(states)``.
        """
        states = [state.make_copy(detach=True, deep=True) for state in data]

        # (attribute name, dimension along which it is padded)
        padded_fields = (
            ("inp_tensor", 1),
            ("gold_tensor", 1),
            ("candtensors", 2),
            ("alignments", 2),
            ("alignment_entropies", 2),
        )
        # NOTE(review): gold_tensor and candtensors are padded independently,
        # so their last dimensions may differ; confirm downstream code does not
        # expect a shared output length for both fields.
        for attr, dim in padded_fields:
            padded = q.pad_tensors([getattr(s, attr) for s in states], dim, 0)
            for s, tensor in zip(states, padded):
                setattr(s, attr, tensor)
        return states[0].merge(states)
Exemplo n.º 2
0
def collate_fn(x, pad_value_nl=0, pad_value_fl=0):
    """Batch three-field examples into stacked, padded tensors.

    Each field is right-padded along dim 0 with ``q.pad_tensors`` and stacked
    into a new leading batch dimension. Fields 0 and 1 are padded with
    ``pad_value_nl``; field 2 is padded with ``pad_value_fl``.

    Args:
        x: sequence of 3-tuples of tensors, one tuple per example.
        pad_value_nl: fill value for the first two fields.
        pad_value_fl: fill value for the third field.

    Returns:
        A tuple of three batched tensors, in field order.
    """
    fields = list(zip(*x))
    assert (len(fields) == 3)
    fill_values = (pad_value_nl, pad_value_nl, pad_value_fl)
    batched = []
    for field, fill in zip(fields, fill_values):
        batched.append(torch.stack(q.pad_tensors(field, 0, fill), 0))
    return tuple(batched)
Exemplo n.º 3
0
def collate_fn(x, pad_value=0):
    """Batch four-field examples into a list of stacked, padded tensors.

    Each field is right-padded along dim 0 with ``q.pad_tensors`` and stacked
    into a new leading batch dimension. The third field is padded with
    ``False`` (presumably a boolean mask -- confirm); the others use
    ``pad_value``.

    Returns:
        A list of four batched tensors, in field order.
    """
    fields = list(zip(*x))
    assert (len(fields) == 4)
    fill_values = (pad_value, pad_value, False, pad_value)
    return [torch.stack(q.pad_tensors(field, 0, fill), 0)
            for field, fill in zip(fields, fill_values)]
Exemplo n.º 4
0
def get_arrays_to_save(x:List[State]):
    """Gather per-state tensors into padded, concatenated arrays.

    For each field, the tensors of all states are padded with
    ``q.pad_tensors`` (dim 1, or dims 1 and 2 for the attentions) and then
    concatenated along dim 0.

    Returns:
        Dict with keys ``"inp_tensor"``, ``"gold_tensor"``,
        ``"followed_actions"`` and ``"attentions"``.
    """
    def _pad_and_concat(tensors, pad_dims):
        # pad to a common size along pad_dims, then join along dim 0
        return torch.cat(q.pad_tensors(tensors, pad_dims), 0)

    inputs = _pad_and_concat([s.inp_tensor for s in x], 1)
    followed = _pad_and_concat([s.followed_actions for s in x], 1)
    golds = _pad_and_concat([s.gold_tensor for s in x], 1)
    atts = _pad_and_concat([s.stored_attentions for s in x], (1, 2))
    return {
        "inp_tensor": inputs,
        "gold_tensor": golds,
        "followed_actions": followed,
        "attentions": atts,
    }
Exemplo n.º 5
0
def pad_and_default_collate(x, pad_value=0):
    """Pad 1-D ``torch.LongTensor`` fields, then defer to ``default_collate``.

    The batch is transposed into per-field columns. Every column whose first
    element is a 1-D ``torch.LongTensor`` is padded along dim 0 with
    ``pad_value`` via ``q.pad_tensors``. Other columns are left unchanged.
    The columns are transposed back into examples and collated with
    ``default_collate``.
    """
    def _pad_if_index_seq(column):
        head = column[0]
        if isinstance(head, torch.LongTensor) and head.dim() == 1:
            return q.pad_tensors(column, 0, pad_value)
        return column

    columns = [_pad_if_index_seq(column) for column in zip(*x)]
    return default_collate(list(zip(*columns)))
Exemplo n.º 6
0
def autocollate(x, pad_value=0):
    """Collate a batch field-by-field, padding index sequences and stacking tensors.

    1. Columns whose first element is a 1-D ``torch.LongTensor`` are padded
       along dim 0 with ``pad_value`` via ``q.pad_tensors``.
    2. Columns whose first element is any ``torch.Tensor`` are stacked along
       a new leading dimension.

    Non-tensor columns are returned as the tuples produced by ``zip``.

    Returns:
        A list with one entry per field.
    """
    columns = list(zip(*x))
    for idx, column in enumerate(columns):
        head = column[0]
        if isinstance(head, torch.LongTensor) and head.dim() == 1:
            columns[idx] = q.pad_tensors(column, 0, pad_value)
    for idx, column in enumerate(columns):
        if isinstance(column[0], torch.Tensor):
            columns[idx] = torch.stack(list(column), 0)
    return columns
Exemplo n.º 7
0
    def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
        """Convert raw ranking examples into ``RankState`` objects grouped by split.

        For every (example, split) pair, this method:
          * encodes the joined input sentence with ``self.sentence_encoder``
            and the joined gold tokens with ``self.query_encoder``;
          * parses the gold tokens (all but the last) into a tree;
          * for every candidate, encodes its tokens, collects its alignment
            and alignment-entropy tensors, parses it into a tree (falling back
            to an ``"@UNK@"`` leaf) and records whether it equals the gold
            tree, treating ``"and"``/``"or"`` children as unordered;
          * pads and stacks the candidate tensors and appends a ``RankState``
            (every tensor given a leading dimension of size 1) to
            ``self.data[split]``.

        Also sets ``self.maxlen_input`` (max number of input tokens) and
        ``self.maxlen_output`` (max last-dim size of gold/candidate tensors).

        Args:
            examples: dicts with keys ``"sentence"`` and ``"gold"`` (token
                sequences) and ``"candidates"`` (list of dicts with
                ``"tokens"``, ``"alignments"`` and ``"align_entropies"``).
            splits: split name for each example; paired with ``examples`` via
                ``zip``, so surplus items in either iterable are ignored.

        ``self.data`` must already be a dict.
        """
        maxlen_in, maxlen_out = 0, 0
        for example, split in zip(examples, splits):
            inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
            # NOTE(review): the order of encoder ``convert`` calls is kept as-is;
            # if the encoders grow their vocabularies while converting, token
            # ids depend on this order -- confirm before reordering.
            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            # The last gold token is dropped before parsing (presumably an
            # end-of-sequence marker -- confirm).
            # NOTE(review): here lisp_to_tree is called with one argument and its
            # result used directly, whereas candidates below pass a second
            # ``None`` argument and unpack two values; confirm both match
            # lisp_to_tree's return convention.
            gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
            # Only guards against None; a non-Tree, non-None result passes.
            if not isinstance(gold_tree, Tree):
                assert (gold_tree is not None)
            gold_tensor, gold_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")

            # candidate_tokens, candidate_trees and gold_tokens are collected but
            # not used further in this method.
            candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
            candidate_align_entropies = []
            candidate_trees = []
            candidate_same = []
            for cand in example["candidates"]:
                cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]),
                                            None)
                # Unparsable candidates become an "@UNK@" leaf so the equality
                # check below still runs (making the following assert always pass).
                if cand_tree is None:
                    cand_tree = Tree("@UNK@", [])
                assert (cand_tree is not None)
                cand_tensor, cand_tokens = self.query_encoder.convert(
                    " ".join(cand["tokens"]), return_what="tensor,tokens")
                candidate_tensors.append(cand_tensor)
                candidate_tokens.append(cand_tokens)
                candidate_align_tensors.append(torch.tensor(
                    cand["alignments"]))
                candidate_align_entropies.append(
                    torch.tensor(cand["align_entropies"]))
                candidate_trees.append(cand_tree)
                # unktoken is a sentinel string (presumably chosen so it never
                # occurs, meaning "@UNK@" gets no special treatment -- confirm).
                candidate_same.append(
                    are_equal_trees(cand_tree,
                                    gold_tree,
                                    orderless={"and", "or"},
                                    unktoken="@NOUNKTOKENHERE@"))

            # Pad candidates to a common length along dim 0 and stack them.
            candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0),
                                           0)
            candidate_align_tensor = torch.stack(
                q.pad_tensors(candidate_align_tensors, 0), 0)
            candidate_align_entropy = torch.stack(
                q.pad_tensors(candidate_align_entropies, 0), 0)
            candidate_same = torch.tensor(candidate_same)

            # Each tensor gets a leading dimension of size 1 (single-example batch).
            state = RankState(
                inp_tensor[None, :],
                gold_tensor[None, :],
                candidate_tensor[None, :, :],
                candidate_same[None, :],
                candidate_align_tensor[None, :],
                candidate_align_entropy[None, :],
                self.sentence_encoder.vocab,
                self.query_encoder.vocab,
            )
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            # Input length is measured in tokens; output length in tensor size.
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, candidate_tensor.size(-1),
                             gold_tensor.size(-1))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out
    def forward(self,
                inpseqs: torch.Tensor = None,
                gold: torch.Tensor = None,
                mode: str = None,
                maxsteps: int = None,
                **kw):
        """
        Decode trees step by step with ``self.tagger``; in training mode also
        compute per-step losses and recall metrics against gold actions.

        Each step linearises every tree with ``extract_info``, encodes it with
        ``self.seqenc``, scores all positions with ``self.tagger``, attaches
        the chosen action and the softmax entropy to each tree, and executes
        the chosen actions on trees smaller than ``self.max_tree_size``.
        Decoding stops once every tree is terminated or ``maxsteps`` steps
        have run.

        In training mode the executed actions are *sampled* from the model's
        softmax restricted to the gold actions, mixed with a uniform
        distribution over the gold actions using ``q.v(self.uniformfactor)``.
        Outside training the argmax action is executed.

        Args:
            inpseqs: forwarded to ``self.tagger.get_init_state``.
            gold: optional gold output sequences, one row per tree. When
                given, each initial tree's ``align`` attribute is set to the
                corresponding gold tree (passed through
                ``add_descendants_ancestors``), and ``ret["treeacc"]`` is
                computed at the end.
            mode: execution mode; defaults to ``self.mode``.
            maxsteps: step limit; defaults to ``self.maxsteps``.
            **kw: accepted and ignored.

        Returns:
            ``(ret, trees)``.

            ``ret`` always holds float tensors ``"seqlens"``, ``"treesizes"``
            and ``"numsteps"``, one entry per tree, created without an
            explicit device:
              * ``seqlens`` and ``treesizes`` are measured at the start of the
                last step;
              * ``numsteps`` counts the steps in which the tree was not yet
                terminated.

            In training mode ``ret`` also holds ``"loss"``, ``"ce"``,
            ``"elemrecall"``, ``"allrecall"``, ``"anyrecall"`` and
            ``"lowestentropyrecall"``. Each is averaged over the steps at
            which that example had at least one gold action.

            If ``gold`` is given, ``"treeacc"`` is added: 1.0/0.0 per example,
            on the model's device.

            ``trees`` are the final trees.

        Raises:
            AssertionError: in training mode, if no decoding step was run.
        """
        maxsteps = maxsteps if maxsteps is not None else self.maxsteps
        mode = mode if mode is not None else self.mode
        device = next(self.parameters()).device

        trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)
        if gold is not None:
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab)
                for seqe in list(gold)
            ]
            goldtrees = [
                add_descendants_ancestors(goldtree) for goldtree in goldtrees
            ]
            # Attach the gold tree to each decoded tree; assumes gold has at least
            # len(trees) rows. (``i`` is reused as the step counter below.)
            for i in range(len(trees)):
                trees[i].align = goldtrees[i]

        i = 0

        if self.training:
            trees = [assign_gold_actions(tree, mode=mode) for tree in trees]
        # choices = [deepcopy(trees)]
        numsteps = [0 for _ in range(len(trees))]
        treesizes = [0 for _ in range(len(trees))]
        seqlens = [0 for _ in range(len(trees))]
        # Per-step metric values; stacked along dim 1 after decoding.
        allmetrics = {}
        allmetrics["loss"] = []
        allmetrics["ce"] = []
        allmetrics["elemrecall"] = []
        allmetrics["allrecall"] = []
        allmetrics["anyrecall"] = []
        allmetrics["lowestentropyrecall"] = []
        examplemasks = []
        while not all([all_terminated(tree)
                       for tree in trees]) and i < maxsteps:
            # go from tree to tensors,
            seq = []
            openmask = []
            stepgold = []
            for j, tree in enumerate(trees):
                if not all_terminated(tree):
                    numsteps[j] += 1
                treesizes[j] = tree_size(tree)

                fltoks, openmask_e, stepgold_e = extract_info(tree)
                seqlens[j] = len(fltoks)
                seq_e = self.seqenc.convert(fltoks, return_what="tensor")
                seq.append(seq_e)
                openmask.append(torch.tensor(openmask_e))
                if self.training:
                    # Multi-hot target: row = position, column = vocab id of a
                    # gold action at that position.
                    stepgold_e_tensor = torch.zeros(
                        seq_e.size(0), self.seqenc.vocab.number_of_ids())
                    # NOTE(review): this inner loop rebinds ``j``, shadowing the tree
                    # index of the enclosing loop. Harmless today because ``j`` is
                    # not read again in this iteration, but fragile.
                    for j, stepgold_e_i in enumerate(stepgold_e):
                        for golde in stepgold_e_i:
                            stepgold_e_tensor[j, self.seqenc.vocab[golde]] = 1
                    stepgold.append(stepgold_e_tensor)
            seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
            openmask = torch.stack(q.pad_tensors(openmask, 0, False),
                                   0).to(device)
            if self.training:
                stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0),
                                       0).to(device)

            if self.training:
                # True for examples with at least one gold action anywhere in this
                # step; used to weight the per-step metrics at the end.
                examplemask = (stepgold != 0).any(-1).any(-1)
                examplemasks.append(examplemask)

            #  feed to tagger,
            probs = self.tagger(seq, openmask=openmask, **context)

            if self.training:
                # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
                ce = self.ce(probs, stepgold, mask=openmask)
                elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(
                    probs, stepgold, mask=openmask)
                allmetrics["loss"].append(ce)
                allmetrics["ce"].append(ce)
                allmetrics["elemrecall"].append(elemrecall)
                allmetrics["allrecall"].append(allrecall)
                allmetrics["anyrecall"].append(anyrecall)
                allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

            #  get best predictions,
            _, best_actions = probs.max(-1)
            # Entropy of softmax(probs); clamped so log() never sees 0.
            entropies = torch.softmax(probs, -1).clamp_min(1e-6)
            entropies = -(entropies * torch.log(entropies)).sum(-1)

            if self.training:
                newprobs = torch.softmax(probs,
                                         -1) * stepgold  # mask using gold
                newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(
                    1e-6)  # renormalize
                # Uniform distribution over the gold actions at each position.
                uniform = stepgold
                uniform = uniform / uniform.sum(-1)[:, :, None].clamp_min(1e-6)
                newprobs = newprobs * (1 - q.v(
                    self.uniformfactor)) + uniform * q.v(self.uniformfactor)

                # Positions with no gold action get all mass on id 0 so that
                # Categorical receives a valid distribution (id 0 is presumably a
                # padding/no-op id -- confirm).
                noprobsmask = (newprobs != 0).any(-1, keepdim=True).float()
                zeroprobs = torch.zeros_like(newprobs)
                zeroprobs[:, :, 0] = 1
                newprobs = newprobs * noprobsmask + (1 -
                                                     noprobsmask) * zeroprobs

                # Sample one gold action per position, reshaped back to (batch, seqlen).
                sampled_gold_actions = torch.distributions.categorical.Categorical(probs=newprobs.view(-1, newprobs.size(-1)))\
                    .sample().view(newprobs.size(0), newprobs.size(1))
                taken_actions = sampled_gold_actions
            else:
                taken_actions = best_actions

            #  attach chosen actions and entropies to existing trees,
            for tree, actions_e, entropies_e in zip(trees, taken_actions,
                                                    entropies):
                actions_e = list(actions_e.detach().cpu().numpy())
                actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
                entropies_e = list(entropies_e.detach().cpu().numpy())
                self.attach_info_to_tree(tree,
                                         _chosen_action=actions_e,
                                         _entropy=entropies_e)
            # trees = [tensors_to_tree(seqe, openmask=openmaske, actions=actione, D=self.seqenc.vocab, entropies=entropies_e)
            #          for seqe, openmaske, actione, entropies_e
            #          in zip(list(seq), list(openmask), list(taken_actions), list(entropies))]

            #  and execute,
            trees_ = []
            for tree in trees:
                # Trees already at the size cap are passed through unchanged.
                if tree_size(tree) < self.max_tree_size:
                    markmode = mode if not self.training else "train-" + mode
                    tree = mark_for_execution(tree,
                                              mode=markmode,
                                              entropylimit=self.entropylimit)
                    budget = [self.max_tree_size - tree_size(tree)]
                    if self.training:
                        tree, _ = convert_chosen_actions_from_str_to_node(tree)
                    tree = execute_chosen_actions(tree,
                                                  _budget=budget,
                                                  mode=mode)
                    if self.training:
                        # Recompute gold actions for the grown tree.
                        tree = assign_gold_actions(tree, mode=mode)
                trees_.append(tree)

            trees = trees_
            i += 1
            #  then repeat until all terminated

        # after done decoding, if gold is given, run losses, else return just predictions

        ret = {}

        ret["seqlens"] = torch.tensor(seqlens).float()
        ret["treesizes"] = torch.tensor(treesizes).float()
        ret["numsteps"] = torch.tensor(numsteps).float()

        if self.training:
            assert (len(examplemasks) > 0)
            assert (len(allmetrics["loss"]) > 0)
            # Stack per-step values along dim 1, then average each example over
            # the steps where its examplemask is set.
            allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
            examplemasks = torch.stack(examplemasks, 1).float()
            _allmetrics = {}
            for k, v in allmetrics.items():
                _allmetrics[k] = (v * examplemasks).sum(1) / examplemasks.sum(
                    1).clamp_min(1e-6)
            allmetrics = _allmetrics
            ret.update(allmetrics)

        if gold is not None:
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab)
                for seqe in list(gold)
            ]
            goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
            predtrees = [simplify_tree_for_eval(x) for x in trees]
            ret["treeacc"] = [
                float(
                    are_equal_trees(gold_tree,
                                    pred_tree,
                                    orderless=ORDERLESS,
                                    unktoken="@UNK@"))
                for gold_tree, pred_tree in zip(goldtrees, predtrees)
            ]
            ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

        return ret, trees
Exemplo n.º 9
0
    def forward(self,
                inpseqs: torch.Tensor = None,
                y_in: torch.Tensor = None,
                gold: torch.Tensor = None,
                mode: str = None,
                maxsteps: int = None,
                **kw):
        """Greedily decode trees and, if ``gold`` is given, score them.

        Starting from ``self.tagger.get_init_state``, each step:
          * linearises every tree with ``extract_info`` and encodes it with
            ``self.seqenc``;
          * runs ``self.tagger`` over the padded batch;
          * takes the highest-scoring action at every position;
          * rebuilds the trees with those actions and the entropy of the
            softmax over the scores;
          * executes the chosen actions on trees still smaller than
            ``self.max_tree_size``.

        Decoding stops when every tree is terminated or ``maxsteps`` steps
        have run.

        Args:
            inpseqs: forwarded to ``get_init_state``.
            y_in: forwarded to ``get_init_state``.
            gold: optional batch of gold output sequences, one per tree.
            mode: execution mode; defaults to ``self.mode``.
            maxsteps: step limit; defaults to ``self.maxsteps``.
            **kw: accepted and ignored.

        Returns:
            ``(ret, trees)``. ``ret`` is empty unless ``gold`` is given. In that
            case ``ret["treeacc"]`` is a float tensor on the model's device
            holding 1.0/0.0 per example for an exact (orderless) tree match.
            ``trees`` are the final decoded trees.
        """
        if maxsteps is None:
            maxsteps = self.maxsteps
        if mode is None:
            mode = self.mode
        device = next(self.parameters()).device

        def _execute(tree):
            # Trees that already reached the size cap are returned untouched.
            if tree_size(tree) < self.max_tree_size:
                tree = mark_for_execution(tree, mode=mode)
                remaining = [self.max_tree_size - tree_size(tree)]
                tree = execute_chosen_actions(tree,
                                              _budget=remaining,
                                              mode=mode)
            return tree

        trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=y_in)

        step = 0
        while not all([all_terminated(t) for t in trees]) and step < maxsteps:
            # Linearise each tree into token ids plus an "open position" mask.
            token_seqs, open_masks = [], []
            for t in trees:
                fltoks, open_e = extract_info(t, nogold=True)
                token_seqs.append(self.seqenc.convert(fltoks, return_what="tensor"))
                open_masks.append(torch.tensor(open_e))
            seq = torch.stack(q.pad_tensors(token_seqs, 0), 0).to(device)
            openmask = torch.stack(q.pad_tensors(open_masks, 0, False),
                                   0).to(device)

            probs = self.tagger(seq, openmask=openmask, **context)

            # Greedy action per position, plus the entropy of the softmax
            # (clamped so log() never sees 0).
            best = probs.max(-1)[1]
            dist = torch.softmax(probs, -1).clamp_min(1e-6)
            entropies = -(dist * torch.log(dist)).sum(-1)

            rebuilt = [
                tensors_to_tree(seq_e,
                                openmask=mask_e,
                                actions=best_e,
                                D=self.seqenc.vocab,
                                entropies=ent_e)
                for seq_e, mask_e, best_e, ent_e in zip(
                    list(seq), list(openmask), list(best), list(entropies))
            ]
            trees = [_execute(t) for t in rebuilt]
            step += 1

        ret = {}
        if gold is not None:
            gold_trees = [tensors_to_tree(g, D=self.seqenc.vocab)
                          for g in list(gold)]
            gold_trees = [simplify_tree_for_eval(g) for g in gold_trees]
            pred_trees = [simplify_tree_for_eval(t) for t in trees]
            matches = [
                float(are_equal_trees(g, p,
                                      orderless=ORDERLESS,
                                      unktoken="@UNK@"))
                for g, p in zip(gold_trees, pred_trees)
            ]
            ret["treeacc"] = torch.tensor(matches).to(device)

        return ret, trees
Exemplo n.º 10
0
    def forward(self, x: State):
        """Run one decoding step for the batch held in ``x``.

        Decoder memory lives in ``x.mstate``, which is created on first use.

        On the first call, the input is encoded (id 0 is treated as padding)
        and the encodings and mask are cached as ``mstate.ctx`` /
        ``mstate.ctx_mask``. The per-layer final encoder states, projected by
        ``self.enc_to_dec``, initialise the decoder RNN's ``h`` when their
        count matches ``h.size(1)``.

        Every call then:
          * embeds ``x.prev_actions``;
          * advances ``self.out_rnn``, also feeding the previous attention
            summary when ``self.feedatt == True``;
          * attends over the input encodings plus all earlier decoder states;
          * scores the next output with ``self.out_lin``, passing copy inputs
            unless ``self.nocopy is True``.

        When ``self.store_attn`` is set, the attention weights are appended to
        ``x.stored_attentions``.

        Args:
            x: decoding state providing ``inp_tensor`` and ``prev_actions``,
                and ``get_out_mask`` when not training.

        Returns:
            ``(out, x)``. ``out`` is the first output of ``self.out_lin``;
            ``x`` is the same (mutated) state object.
        """
        if "mstate" not in x:
            x.mstate = State()
        mstate = x.mstate

        init_states = []
        if "ctx" not in mstate:
            # Encode the input once and cache it on mstate.
            inp_ids = x.inp_tensor
            inp_mask = inp_ids != 0
            inp_embs = self.inp_emb(inp_ids)
            # inp_embs = self.dropout(inp_embs)
            inp_encs, final_encs = self.inp_enc(inp_embs, inp_mask)
            # one projected initial state per encoder layer
            init_states = [self.enc_to_dec[layer](layer_final[0])
                           for layer, layer_final in enumerate(final_encs)]
            mstate.ctx = inp_encs
            mstate.ctx_mask = inp_mask

        ctx, ctx_mask = mstate.ctx, mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if "rnnstate" not in mstate:
            rnn_init = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # Seed h with the projected encoder states only if the counts match.
            if len(init_states) == rnn_init.h.size(1):
                rnn_init.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = rnn_init

        if "prev_summ" not in mstate:
            # NOTE(review): final_encs only exists if the input was encoded in this
            # same call; an mstate that already has ``ctx`` but no ``prev_summ``
            # raises UnboundLocalError here.
            mstate.prev_summ = final_encs[-1][0]

        rnn_inp = torch.cat([emb, mstate.prev_summ], 1) if self.feedatt == True else emb
        enc, mstate.rnnstate = self.out_rnn(rnn_inp, mstate.rnnstate)

        if "prevstates" in mstate:
            prev = mstate.prevstates
            # Attend over the input plus every earlier decoder state; decoder
            # states are never masked out.
            att_ctx = torch.cat([ctx, prev], 1)
            att_mask = torch.cat([
                ctx_mask,
                torch.ones(prev.size(0),
                           prev.size(1),
                           dtype=ctx_mask.dtype,
                           device=ctx_mask.device)
            ], 1)
            mstate.prevstates = torch.cat([prev, enc[:, None, :]], 1)
        else:
            att_ctx, att_mask = ctx, ctx_mask
            mstate.prevstates = enc[:, None, :]

        alphas, summ, scores = self.att(enc, att_ctx, att_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        out_mask = None if self.training else x.get_out_mask(device=enc.device)

        if self.nocopy is True:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        if not q.issequence(outs):
            outs = (outs, )

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0),
                                                  0,
                                                  alphas.size(1),
                                                  device=alphas.device)
            # The attended context grows as decoder states accumulate, so pad
            # the stored history and the new step to a common size on dim 2.
            atts = q.pad_tensors(
                [x.stored_attentions,
                 alphas.detach()[:, None, :]], 2, 0)
            x.stored_attentions = torch.cat(atts, 1)

        return outs[0], x