Example #1
def collate_fn(batch):
    if len(batch[0]) == 3:
        fc_feat, att_feat, trees = zip(*batch)
        fc_feat, att_feat = torch.from_numpy(np.stack(
            fc_feat, axis=0)), torch.from_numpy(np.stack(att_feat, axis=0))
        feats = (fc_feat, att_feat)
    elif len(batch[0]) == 2:
        fc_feat, trees = zip(*batch)
        fc_feat = torch.from_numpy(np.stack(fc_feat, axis=0))
        feats = (fc_feat, )
    else:
        raise ValueError("unexpected sample length: {}".format(len(batch[0])))

    # Merge the per-sample roots into a single batched root node.
    root_lex = [tree.root.lex for tree in trees]
    root_lex = '(' + ' '.join(root_lex) + ')'
    root_node = TensorNode(Node(root_lex, 1, ""))
    root_node.word_idx = torch.from_numpy(
        np.stack([tree.root.word_idx for tree in trees], axis=0))

    batch_tree = Tree()
    batch_tree.root = root_node
    batch_tree.nodes.update({0: root_node})

    batch_tree_pad([n.root for n in trees], root_node, batch_tree)
    batch_tree.update()

    return feats + (batch_tree, )
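
For context, a custom collate_fn like the one above is passed to torch.utils.data.DataLoader via its collate_fn argument. The sketch below is a minimal, self-contained illustration of the same feature-stacking pattern on dummy data; the dataset, feature shapes, and batch size are illustrative assumptions, and the tree merging is omitted.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class DummyFeatureDataset(Dataset):
    # Yields (fc_feat, att_feat) numpy pairs, mimicking the first two
    # elements of the samples collated above (trees omitted).
    def __len__(self):
        return 8

    def __getitem__(self, i):
        fc_feat = np.random.rand(2048).astype(np.float32)
        att_feat = np.random.rand(36, 2048).astype(np.float32)
        return fc_feat, att_feat


def simple_collate(batch):
    # Same stacking pattern as collate_fn above, minus the tree merging.
    fc_feat, att_feat = zip(*batch)
    fc_feat = torch.from_numpy(np.stack(fc_feat, axis=0))
    att_feat = torch.from_numpy(np.stack(att_feat, axis=0))
    return fc_feat, att_feat


loader = DataLoader(DummyFeatureDataset(), batch_size=4, collate_fn=simple_collate)
fc, att = next(iter(loader))
print(fc.shape, att.shape)  # torch.Size([4, 2048]) torch.Size([4, 36, 2048])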
Example #2
    def greedy_search(self, fc_feat, vocab, max_seq_length):
        h_r, probs_r = self.from_scratch(fc_feat)
        _, word_idx_r = probs_r.max(dim=1)
        lex_r = vocab.idx2word[word_idx_r.item()]
        word_idx_r = torch.LongTensor([word_idx_r])

        node_r = TensorNode(Node(lex_r, 1, ""))
        node_r.word_idx = word_idx_r
        node_r.h = h_r

        tree = Tree()
        tree.root = node_r
        tree.nodes.update({node_r.idx - 1: node_r})

        # Attach decoding context, then grow the tree breadth-first.
        tree.model = self
        tree.vocab = vocab
        tree.max_seq_length = max_seq_length
        tree.bfs_breed()
        return tree
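
The core of greedy_search is a single greedy step: take the argmax of the model's probability row and map the index back to a word with the vocabulary's idx2word table. A self-contained toy version of that step, with an illustrative stand-in vocabulary, is sketched below.

import torch

# Toy stand-in for the project's vocab.idx2word mapping.
idx2word = {0: '<PAD>', 1: '<EOB>', 2: 'a', 3: 'dog'}

probs = torch.tensor([[0.05, 0.10, 0.25, 0.60]])   # shape (1, vocab_size)
score, word_idx = probs.max(dim=1)                 # greedy choice per row
print(word_idx.item(), idx2word[word_idx.item()])  # -> 3 dog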
Example #3
def create_node(idx, lex, word_idx, h, depth):
    node = TensorNode(Node(lex, idx, ""))
    node.label = lex + '-' + str(idx)
    node.word_idx = word_idx.clone()
    # Clone the (h, c) recurrent state so nodes never share tensor storage.
    node.h = (h[0].clone(), h[1].clone())
    node.depth = depth
    return node
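
A brief note on the clone() calls: they give each node its own copy of the LSTM (h, c) state, so an in-place update applied through one node cannot silently change another's. The snippet below (with illustrative shapes) demonstrates the effect.

import torch

h = torch.zeros(1, 4)
c = torch.zeros(1, 4)

state_for_node = (h.clone(), c.clone())  # independent copy, as in create_node
h.add_(1.0)                              # mutate the original state in place
print(state_for_node[0].sum().item())    # 0.0 -- the node's copy is unaffected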
Example #4
def batch_tree_pad(nodes, guard, tree):
    # `nodes` holds, for every sample in the batch, the node sitting at the
    # same tree position as `guard`; missing nodes are None.
    lc = [n.lc if n is not None else None for n in nodes]
    mc = [n.mc if n is not None else None for n in nodes]
    rc = [n.rc if n is not None else None for n in nodes]

    # No sample has children at this position, so the recursion stops here.
    if not any(lc):
        return

    idx = len(tree.nodes) + 1

    lex = [n.lex if n is not None else '<PAD>' for n in lc]
    lex = '(' + ' '.join(lex) + ')'
    word_idx = [n.word_idx if n is not None else 3 for n in lc]  # 3: index of '<PAD>'
    new_lc = TensorNode(Node(lex, idx, ""))
    new_lc.word_idx = torch.LongTensor(word_idx)
    guard.lc = new_lc

    lex = [n.lex if n is not None else '<PAD>' for n in mc]
    lex = '(' + ' '.join(lex) + ')'
    word_idx = [n.word_idx if n is not None else 3 for n in mc]
    new_mc = TensorNode(Node(lex, idx + 1, ""))
    new_mc.word_idx = torch.LongTensor(word_idx)
    guard.mc = new_mc

    lex = [n.lex if n is not None else '<PAD>' for n in rc]
    lex = '(' + ' '.join(lex) + ')'
    word_idx = [n.word_idx if n is not None else 3 for n in rc]
    new_rc = TensorNode(Node(lex, idx + 2, ""))
    new_rc.word_idx = torch.LongTensor(word_idx)
    guard.rc = new_rc

    tree.nodes.update({new_lc.idx: new_lc})
    tree.nodes.update({new_mc.idx: new_mc})
    tree.nodes.update({new_rc.idx: new_rc})

    batch_tree_pad(lc, new_lc, tree)
    batch_tree_pad(mc, new_mc, tree)
    batch_tree_pad(rc, new_rc, tree)
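
The recursion above aligns the same child slot across all trees in the batch and inserts a pad entry wherever a tree has no child at that position. A minimal, self-contained sketch of that alignment step on plain tuples is shown below; the pad index 3 mirrors the value used above and is an assumption about the project's vocabulary.

import torch

PAD_LEX, PAD_IDX = '<PAD>', 3

# Left children of three roots; None means that tree has no left child.
lc = [('dog', 7), None, ('cat', 9)]

lex = '(' + ' '.join(n[0] if n is not None else PAD_LEX for n in lc) + ')'
word_idx = torch.LongTensor([n[1] if n is not None else PAD_IDX for n in lc])

print(lex)       # (dog <PAD> cat)
print(word_idx)  # tensor([7, 3, 9])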
Example #5
    def _sample(self, fc_feats, vocab, max_seq_length):
        fc_feats = self.fc_embed(fc_feats)

        states_r, probs_r = self.from_scratch(fc_feats)
        logprob, word_idx_r = probs_r.max(dim=1)
        lex_r = vocab.idx2word[word_idx_r.item()]
        word_idx_r = torch.LongTensor([word_idx_r])

        node_r = TensorNode(Node(lex_r, 1, ""))
        node_r.word_idx = word_idx_r
        node_r.h = states_r
        node_r.logprob = logprob.item()

        tree = Tree()
        tree.root = node_r
        tree.nodes.update({node_r.idx - 1: node_r})

        # Expand the tree breadth-first until every branch ends in <EOB> or
        # the node budget max_seq_length is reached.
        queue = Queue()
        queue.put(tree.root)
        print(tree.root.lex)
        while not queue.empty() and len(tree.nodes) <= max_seq_length:
            node = queue.get()

            if node.lex == '<EOB>':
                continue

            idx = len(tree.nodes) + 1
            lc = TensorNode(Node("", idx, ""))
            lc.parent = node
            self.complete_node(lc, vocab, fc_feats)
            node.lc = lc
            lc.depth = node.depth + 1

            mc = TensorNode(Node("", idx + 1, ""))
            mc.parent = node
            mc.left_brother = lc
            self.complete_node(mc, vocab, fc_feats)
            node.mc = mc
            mc.depth = node.depth + 1
            lc.right_brother = mc

            rc = TensorNode(Node("", idx + 2, ""))
            rc.parent = node
            rc.left_brother = mc
            self.complete_node(rc, vocab, fc_feats)
            node.rc = rc
            rc.depth = node.depth + 1
            mc.right_brother = rc

            tree.nodes.update({lc.idx - 1: lc})
            tree.nodes.update({mc.idx - 1: mc})
            tree.nodes.update({rc.idx - 1: rc})

            queue.put(lc)
            queue.put(mc)
            queue.put(rc)

        return tree
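
The loop above is a plain breadth-first expansion: pop a node from a queue, skip it if it is an end-of-branch marker, otherwise attach three children and enqueue them, stopping once the node budget is exhausted. The stand-alone sketch below reproduces that control flow on dict nodes; the expand() stub is a purely illustrative replacement for complete_node().

from queue import Queue


def expand(parent_lex, slot):
    # Stand-in for complete_node(): always emits <EOB>, so the toy tree terminates.
    return '<EOB>'


nodes = {0: {'lex': 'root', 'children': []}}
q = Queue()
q.put(0)
max_nodes = 10

while not q.empty() and len(nodes) <= max_nodes:
    nid = q.get()
    if nodes[nid]['lex'] == '<EOB>':
        continue
    for slot in ('lc', 'mc', 'rc'):
        cid = len(nodes)
        nodes[cid] = {'lex': expand(nodes[nid]['lex'], slot), 'children': []}
        nodes[nid]['children'].append(cid)
        q.put(cid)

print(len(nodes), [n['lex'] for n in nodes.values()])  # 4 ['root', '<EOB>', '<EOB>', '<EOB>']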
Example #6
    def beam_search(self,
                    fc_feat,
                    vocab,
                    max_seq_length,
                    global_beam_size,
                    local_beam_size,
                    depth_normalization_factor=0.8):

        (h_r, c_r), logprobs_r = self.from_scratch(fc_feat)
        ys, ix = torch.sort(logprobs_r, 1, True)
        candidates = []
        for c in range(global_beam_size):
            local_logprob = ys[0, c].item()
            candidates.append({'c': ix[0, c], 'p': local_logprob})
        candidate_trees = []

        for c in range(global_beam_size):
            lex = vocab.idx2word[candidates[c]['c'].item()]
            node = TensorNode(Node(lex, 1, ""))
            node.word_idx = candidates[c]['c']
            node.h = h_r.clone(), c_r.clone()
            node.logprob = candidates[c]['p']
            tree = Tree()
            tree.root = node
            tree.logprob = candidates[c]['p']
            tree.nodes.update({node.idx - 1: node})
            candidate_trees.append(tree)

        # for t in range((max_seq_length - 1) // 3):
        completed_sentences = []

        # for t in range(10):
        while True:
            new_candidates = []
            for tree_idx, tree in enumerate(candidate_trees):
                # print(tree.__str__())
                # print([_node.lex for _node in tree.nodes.values()])
                # A tree is finished ("dead") once every leaf is an <EOB> node.
                dead_tree = True
                for node in tree.nodes.values():
                    if node.lc is None and node.lex != '<EOB>':
                        dead_tree = False
                if dead_tree:
                    temp_candidate_dict = {
                        'p': tree.logprob,
                        'idx': tree_idx,
                        'dead_tree': True
                    }
                    new_candidates.append(temp_candidate_dict)
                    print("done sentence: {}".format(tree.__str__()))
                    continue

                new_candidates_per_tree = []
                for node in tree.nodes.values():

                    current_depth = node.depth

                    # Generate local_beam_size candidate (lc, mc, rc) triples for
                    # this node via a chained beam search over the three child slots.
                    if node.lc is not None or node.lex == '<EOB>':
                        continue
                    local_candidates = []

                    (hs, cs), logprobs = self.beam_step(
                        state_a=(node.h[0].clone(), node.h[1].clone()),
                        x_a=node.word_idx.clone())
                    ys, ix = torch.sort(logprobs, 1, True)
                    for c in range(local_beam_size):
                        local_logprob = ys[0, c].item()
                        local_candidates.append({
                            'c': ix[0, c],
                            'p': local_logprob,
                            'seq_p': [local_logprob]
                        })

                    # print("input: father ({}), left sibling (None)".format(node.lex))
                    # print("output candidates: {}".format([vocab.idx2word[cdd['c'].item()] for cdd in local_candidates]))

                    m_local_candidates = []
                    for c in range(local_beam_size):
                        (_h, _c), _logprobs = self.beam_step(
                            state_a=(node.h[0].clone(), node.h[1].clone()),
                            x_a=node.word_idx.clone(),
                            state_f=(hs.clone(), cs.clone()),
                            x_f=local_candidates[c]['c'].clone())
                        _ys, _ix = torch.sort(_logprobs, 1, True)
                        for q in range(local_beam_size):
                            local_logprob = _ys[0, q].item()
                            entry = {
                                'c': [local_candidates[c]['c'], _ix[0, q]],
                                'p':
                                local_logprob + local_candidates[c]['p'],
                                'seq_p':
                                local_candidates[c]['seq_p'] + [local_logprob],
                                'state': (_h, _c)
                            }
                            m_local_candidates.append(entry)

                            # print("input: father ({}), left sibling ({})".format(node.lex, vocab.idx2word[local_candidates[c]['c'].item()]))
                            # print("output candidates: {}".format(
                            #     [vocab.idx2word[cdd['c'][1].item()] for cdd in m_local_candidates]))

                    m_local_candidates = sorted(m_local_candidates,
                                                key=lambda x: -x['p'])
                    m_local_candidates = m_local_candidates[:local_beam_size]

                    # print([{'c': cdd['c'], 'p': cdd['p'], 'state_len': len(cdd['state']), 'state_0_len': len(cdd['state'][0])} for cdd in m_local_candidates])

                    r_local_candidates = []
                    for c in range(local_beam_size):
                        f_state = (m_local_candidates[c]['state'][0].clone(),
                                   m_local_candidates[c]['state'][1].clone())
                        (_h, _c), _logprobs = self.beam_step(
                            state_a=(node.h[0].clone(), node.h[1].clone()),
                            x_a=node.word_idx.clone(),
                            state_f=f_state,
                            x_f=m_local_candidates[c]['c'][-1])
                        _ys, _ix = torch.sort(_logprobs, 1, True)
                        for q in range(local_beam_size):
                            local_logprob = _ys[0, q].item()
                            entry = {
                                'c': [
                                    _.clone()
                                    for _ in m_local_candidates[c]['c']
                                ],
                                'p':
                                local_logprob + m_local_candidates[c]['p'],
                                'seq_p':
                                m_local_candidates[c]['seq_p'] +
                                [local_logprob],
                                'state':
                                [(m_local_candidates[c]['state'][0].clone(),
                                  m_local_candidates[c]['state'][1].clone()),
                                 (_h.clone(), _c.clone())]
                            }
                            entry['c'].append(_ix[0, q].clone())
                            r_local_candidates.append(entry)

                    r_local_candidates = sorted(r_local_candidates,
                                                key=lambda x: -x['p'])
                    r_local_candidates = r_local_candidates[:local_beam_size]

                    # print([{'c': cdd['c'], 'p': cdd['p'], 'state_len': len(cdd['state']), 'state_0_len': len(cdd['state'][0])} for cdd in r_local_candidates])

                    for candidate in r_local_candidates:
                        candidate['state'].insert(0, (hs, cs))
                        # filled in later, after the top-L combinations are selected:
                        # candidate['p'] += tree.logprob
                        # candidate['idx'] = tree_idx
                        candidate['fid'] = node.idx
                        # candidate['dead_tree'] = False

                    new_candidates_per_tree.append(r_local_candidates)

                # new_candidates_per_tree now has one entry per expandable node
                # (every node that is neither expanded nor <EOB>), so there are
                # local_beam_size ** len(new_candidates_per_tree) possible
                # combinations. Rather than enumerating them exhaustively, prune
                # the combinations with another beam search.

                combination_candidates = []
                for i in range(len(new_candidates_per_tree)):
                    if i == 0:
                        for j in range(local_beam_size):
                            combination_candidates.append({
                                'p':
                                new_candidates_per_tree[i][j]['p'],
                                'seq': [j]
                            })
                        continue
                    new_combination_candidates = []
                    for p in range(local_beam_size):
                        for q in range(local_beam_size):
                            local_combination_logprob = new_candidates_per_tree[
                                i][q]['p'] + combination_candidates[p]['p']
                            local_seq = combination_candidates[p]['seq'] + [q]
                            new_combination_candidates.append({
                                'p': local_combination_logprob,
                                'seq': local_seq
                            })
                    new_combination_candidates = sorted(
                        new_combination_candidates, key=lambda x: -x['p'])
                    new_combination_candidates = new_combination_candidates[:
                                                                            local_beam_size]
                    combination_candidates = new_combination_candidates

                # print(len(combination_candidates))

                done_combination_candidates = []
                for i in range(local_beam_size):
                    done_combination_candidates.append({
                        'p':
                        combination_candidates[i]['p'] + tree.logprob,
                        'combination': [
                            new_candidates_per_tree[ni][wi] for ni, wi in
                            enumerate(combination_candidates[i]['seq'])
                        ],
                        'dead_tree':
                        False,
                        'idx':
                        tree_idx
                    })

                new_candidates.extend(done_combination_candidates)

            # print(len(new_candidates))
            new_candidates = sorted(new_candidates, key=lambda x: -x['p'])
            new_candidates = new_candidates[:global_beam_size]

            new_candidate_trees = []
            for _candidate in new_candidates:
                if _candidate['dead_tree']:
                    # Finished trees are carried over unchanged.
                    new_candidate_trees.append(
                        candidate_trees[_candidate['idx']].clone())
                    continue

                tree_idx = _candidate['idx']
                new_tree = candidate_trees[tree_idx].clone()

                for candidate in _candidate['combination']:
                    f_node = new_tree.nodes[candidate['fid'] - 1]

                    # print(f_node.__repr__())
                    idx = len(new_tree.nodes) + 1
                    lex = [vocab.idx2word[_.item()] for _ in candidate['c']]
                    seq_p = candidate['seq_p']

                    lc = TensorNode(Node(lex[0], idx, ""))
                    lc.parent = f_node
                    lc.word_idx = candidate['c'][0].clone()
                    f_node.lc = lc
                    lc.depth = f_node.depth + 1
                    lc.h = tuple_clone(candidate['state'][0])
                    lc.logprob = seq_p[0]

                    mc = TensorNode(Node(lex[1], idx + 1, ""))
                    mc.parent = f_node
                    mc.left_brother = lc
                    mc.word_idx = candidate['c'][1].clone()
                    f_node.mc = mc
                    mc.depth = f_node.depth + 1
                    mc.h = tuple_clone(candidate['state'][1])
                    lc.right_brother = mc
                    mc.logprob = seq_p[1]

                    rc = TensorNode(Node(lex[2], idx + 2, ""))
                    rc.parent = f_node
                    rc.left_brother = mc
                    rc.word_idx = candidate['c'][2].clone()
                    f_node.rc = rc
                    rc.depth = f_node.depth + 1
                    rc.h = tuple_clone(candidate['state'][2])
                    mc.right_brother = rc
                    rc.logprob = seq_p[2]

                    new_tree.nodes.update({lc.idx - 1: lc})
                    new_tree.nodes.update({mc.idx - 1: mc})
                    new_tree.nodes.update({rc.idx - 1: rc})

                    new_tree.logprob += candidate['p']

                # with open('beam_search.dot', 'w') as f:
                #     f.write(new_tree.graphviz())
                #
                # exit(0)
                dead_tree = True
                for node in new_tree.nodes.values():
                    if node.lc is None and node.lex != '<EOB>':
                        dead_tree = False

                if dead_tree:
                    # new_tree.logprob /= len(new_tree.nodes)**3.0
                    completed_sentences.append(new_tree)
                else:
                    new_candidate_trees.append(new_tree)

            candidate_trees = new_candidate_trees

            # if len(candidate_trees) == 0:
            #     break
            break_flag = True
            for tree in candidate_trees:
                if len(tree.nodes) < max_seq_length:
                    break_flag = False
            if break_flag:
                break

        return candidate_trees, completed_sentences
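
At the heart of the method is a chained beam search per node: keep the local_beam_size best partial child sequences, extend each with every possible next symbol, and re-prune by accumulated log-probability, once per child slot (lc, then mc, then rc). The self-contained sketch below shows that keep/extend/prune loop on a toy log-probability table; in the real code the scores come from beam_step().

import math

# Toy transition log-probabilities: logprob[prev][next]. Illustrative only.
logprob = {
    '<s>': {'a': math.log(0.6), 'b': math.log(0.4)},
    'a':   {'a': math.log(0.1), 'b': math.log(0.9)},
    'b':   {'a': math.log(0.5), 'b': math.log(0.5)},
}
beam_size = 2

beams = [{'seq': ['<s>'], 'p': 0.0}]
for _ in range(3):  # one extension per child slot: lc, mc, rc
    expanded = []
    for beam in beams:
        last = beam['seq'][-1]
        for sym, lp in logprob[last].items():
            expanded.append({'seq': beam['seq'] + [sym], 'p': beam['p'] + lp})
    # Keep only the beam_size best partial sequences, as in the code above.
    beams = sorted(expanded, key=lambda x: -x['p'])[:beam_size]

for beam in beams:
    print(beam['seq'], round(beam['p'], 3))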
Example #7
    def _sample(self, fc_feats, vocab, max_seq_length):
        state_a = self.init_state(1)
        xt_a = self.fc_embed(fc_feats)
        state_f = self.init_state(1)
        xt_f = self.init_input(1)

        output, state = self.md_lstm(xt_a, xt_f, state_a, state_f)
        logits = self.logit(output)
        logprobs = F.log_softmax(logits, dim=1)

        logprob, word_idx_r = logprobs.max(dim=1)
        word_idx_r = torch.LongTensor([word_idx_r])

        node_r = TensorNode(Node("", 1, ""))
        node_r.word_idx = word_idx_r
        node_r.word_embed = self.embed(word_idx_r)
        # node_r.h = output, state
        node_r.h = state
        node_r.logprob = logprob.item()

        tree = Tree()
        tree.root = node_r
        tree.nodes.update({node_r.idx - 1: node_r})

        eob_idx = vocab('<EOB>')

        queue = Queue()
        queue.put(tree.root)
        while not queue.empty() and len(tree.nodes) <= max_seq_length:
            node = queue.get()

            if node.word_idx.item() == eob_idx:
                continue

            idx = len(tree.nodes) + 1
            lc = TensorNode(Node("", idx, ""))
            lc.parent = node
            self.complete_node(lc)
            node.lc = lc
            lc.depth = node.depth + 1

            mc = TensorNode(Node("", idx + 1, ""))
            mc.parent = node
            mc.left_brother = lc
            self.complete_node(mc)
            node.mc = mc
            mc.depth = node.depth + 1
            lc.right_brother = mc

            rc = TensorNode(Node("", idx + 2, ""))
            rc.parent = node
            rc.left_brother = mc
            self.complete_node(rc)
            node.rc = rc
            rc.depth = node.depth + 1
            mc.right_brother = rc

            tree.nodes.update({lc.idx - 1: lc})
            tree.nodes.update({mc.idx - 1: mc})
            tree.nodes.update({rc.idx - 1: rc})

            queue.put(lc)
            queue.put(mc)
            queue.put(rc)

        # Resolve word indices into lexemes and display labels once decoding ends.
        for node in tree.nodes.values():
            node.lex = vocab.idx2word[node.word_idx.item()]
            node.label = node.lex + '-' + str(node.idx)
        return tree
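
The first step of _sample projects the decoder output to vocabulary logits (self.logit), applies log_softmax, and takes the argmax index. The sketch below isolates that decoding head; the linear layer and sizes are illustrative stand-ins for the model's own.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, vocab_size = 8, 5
logit = nn.Linear(hidden_size, vocab_size)    # stand-in for self.logit

output = torch.randn(1, hidden_size)          # one decoder step's output
logprobs = F.log_softmax(logit(output), dim=1)
logprob, word_idx = logprobs.max(dim=1)
print(word_idx.item(), logprob.item())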