def collate_fn(batch):
    # Batch samples that carry either (fc_feat, att_feat, tree) or (fc_feat, tree).
    # Features are stacked into tensors; the per-sample trees are merged into one
    # padded "batch tree" whose every node holds the word indices of the
    # corresponding node across the batch.
    if len(batch[0]) == 3:
        fc_feat, att_feat, trees = zip(*batch)
        fc_feat = torch.from_numpy(np.stack(fc_feat, axis=0))
        att_feat = torch.from_numpy(np.stack(att_feat, axis=0))
        feats = (fc_feat, att_feat)
    elif len(batch[0]) == 2:
        fc_feat, trees = zip(*batch)
        fc_feat = torch.from_numpy(np.stack(fc_feat, axis=0))
        feats = (fc_feat, )
    else:
        raise ValueError("each sample must have 2 or 3 elements, got {}".format(len(batch[0])))

    # Root of the batch tree: root lexemes joined for readability, root word
    # indices stacked along the batch dimension.
    root_lex = '(' + ' '.join(tree.root.lex for tree in trees) + ')'
    root_node = TensorNode(Node(root_lex, 1, ""))
    root_node.word_idx = torch.from_numpy(
        np.stack([tree.root.word_idx for tree in trees], axis=0))

    batch_tree = Tree()
    batch_tree.root = root_node
    batch_tree.nodes.update({0: root_node})
    batch_tree_pad([t.root for t in trees], root_node, batch_tree)
    batch_tree.update()
    return feats + (batch_tree, )
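# Usage sketch (an illustrative assumption, not part of the original code): this
# collate_fn is designed to be passed to torch.utils.data.DataLoader, so every
# batch arrives as stacked feature tensors plus one padded batch tree. The
# dataset name and batch size below are hypothetical.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(caption_tree_dataset, batch_size=16, shuffle=True,
#                         collate_fn=collate_fn)
#     for fc_feat, att_feat, batch_tree in loader:   # 3-element samples
#         ...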
def greedy_search(self, fc_feat, vocab, max_seq_length):
    # Greedy decoding: pick the argmax word for the root from the image feature,
    # then let Tree.bfs_breed() grow the rest of the tree breadth-first, always
    # taking the most probable word for each child.
    h_r, probs_r = self.from_scratch(fc_feat)
    _, word_idx_r = probs_r.max(dim=1)
    lex_r = vocab.idx2word[word_idx_r.item()]
    word_idx_r = torch.LongTensor([word_idx_r])

    node_r = TensorNode(Node(lex_r, 1, ""))
    node_r.word_idx = word_idx_r
    node_r.h = h_r

    tree = Tree()
    tree.root = node_r
    tree.nodes.update({node_r.idx - 1: node_r})
    tree.model = self
    tree.vocab = vocab
    tree.max_seq_length = max_seq_length
    tree.bfs_breed()
    return tree
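# Usage sketch (illustrative, not from the original code): greedy decoding of a
# caption tree for a single image. `model`, `vocab` and the 2048-d feature size
# are assumptions; from_scratch() and Tree.bfs_breed() are the calls used above.
#
#     fc_feat = torch.randn(1, 2048)
#     tree = model.greedy_search(fc_feat, vocab, max_seq_length=60)
#     print(tree)   # assuming Tree.__str__ renders the decoded sentence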
def create_node(idx, lex, word_idx, h, depth):
    node = TensorNode(Node(lex, idx, ""))
    node.label = lex + '-' + str(idx)
    node.word_idx = word_idx.clone()
    node.h = (h[0].clone(), h[1].clone())
    node.depth = depth
    return node
def batch_tree_pad(nodes, guard, tree):
    # Recursively mirror the children of every per-sample node onto the batch
    # tree. `nodes` holds the nodes at the same position in each sample's tree
    # (None where a sample has no node there); `guard` is the corresponding
    # batch-tree node. Missing children are padded with '<PAD>' (word index 3 is
    # assumed to be the <PAD> entry of the vocabulary).
    lc = [n.lc if n is not None else None for n in nodes]
    mc = [n.mc if n is not None else None for n in nodes]
    rc = [n.rc if n is not None else None for n in nodes]
    if not any(lc):
        return
    idx = len(tree.nodes) + 1

    # Left children of the batch.
    lex = '(' + ' '.join(n.lex if n is not None else '<PAD>' for n in lc) + ')'
    word_idx = [n.word_idx if n is not None else 3 for n in lc]
    new_lc = TensorNode(Node(lex, idx, ""))
    new_lc.word_idx = torch.LongTensor(word_idx)
    guard.lc = new_lc

    # Middle children of the batch.
    lex = '(' + ' '.join(n.lex if n is not None else '<PAD>' for n in mc) + ')'
    word_idx = [n.word_idx if n is not None else 3 for n in mc]
    new_mc = TensorNode(Node(lex, idx + 1, ""))
    new_mc.word_idx = torch.LongTensor(word_idx)
    guard.mc = new_mc

    # Right children of the batch.
    lex = '(' + ' '.join(n.lex if n is not None else '<PAD>' for n in rc) + ')'
    word_idx = [n.word_idx if n is not None else 3 for n in rc]
    new_rc = TensorNode(Node(lex, idx + 2, ""))
    new_rc.word_idx = torch.LongTensor(word_idx)
    guard.rc = new_rc

    tree.nodes.update({new_lc.idx: new_lc})
    tree.nodes.update({new_mc.idx: new_mc})
    tree.nodes.update({new_rc.idx: new_rc})
    batch_tree_pad(lc, new_lc, tree)
    batch_tree_pad(mc, new_mc, tree)
    batch_tree_pad(rc, new_rc, tree)
def _sample(self, fc_feats, vocab, max_seq_length):
    # Greedy sampling: decode the root from the embedded image feature, then grow
    # the tree breadth-first, asking complete_node() to predict each of the three
    # children (left, middle, right). Expansion stops at '<EOB>' nodes or once
    # max_seq_length nodes have been generated.
    fc_feats = self.fc_embed(fc_feats)
    states_r, probs_r = self.from_scratch(fc_feats)
    logprob, word_idx_r = probs_r.max(dim=1)
    lex_r = vocab.idx2word[word_idx_r.item()]
    word_idx_r = torch.LongTensor([word_idx_r])

    node_r = TensorNode(Node(lex_r, 1, ""))
    node_r.word_idx = word_idx_r
    node_r.h = states_r
    node_r.logprob = logprob.item()

    tree = Tree()
    tree.root = node_r
    tree.nodes.update({node_r.idx - 1: node_r})

    queue = Queue()
    queue.put(tree.root)
    print(tree.root.lex)  # debug output of the predicted root word
    while not queue.empty() and len(tree.nodes) <= max_seq_length:
        node = queue.get()
        if node.lex == '<EOB>':
            continue
        idx = len(tree.nodes) + 1

        lc = TensorNode(Node("", idx, ""))
        lc.parent = node
        self.complete_node(lc, vocab, fc_feats)
        node.lc = lc
        lc.depth = node.depth + 1

        mc = TensorNode(Node("", idx + 1, ""))
        mc.parent = node
        mc.left_brother = lc
        self.complete_node(mc, vocab, fc_feats)
        node.mc = mc
        mc.depth = node.depth + 1
        lc.right_brother = mc

        rc = TensorNode(Node("", idx + 2, ""))
        rc.parent = node
        rc.left_brother = mc
        self.complete_node(rc, vocab, fc_feats)
        node.rc = rc
        rc.depth = node.depth + 1
        mc.right_brother = rc

        tree.nodes.update({lc.idx - 1: lc})
        tree.nodes.update({mc.idx - 1: mc})
        tree.nodes.update({rc.idx - 1: rc})
        queue.put(lc)
        queue.put(mc)
        queue.put(rc)
    return tree
def beam_search(self, fc_feat, vocab, max_seq_length, global_beam_size,
                local_beam_size, depth_normalization_factor=0.8):
    # Tree-structured beam search. A global beam of size global_beam_size is kept
    # over whole trees; for every expandable node of every tree, a local beam of
    # size local_beam_size proposes (left, middle, right) child triples via a
    # chained beam search.
    (h_r, c_r), logprobs_r = self.from_scratch(fc_feat)
    ys, ix = torch.sort(logprobs_r, 1, True)

    # The top global_beam_size root words each start one candidate tree.
    candidates = []
    for c in range(global_beam_size):
        local_logprob = ys[0, c].item()
        candidates.append({'c': ix[0, c], 'p': local_logprob})

    candidate_trees = []
    for c in range(global_beam_size):
        lex = vocab.idx2word[candidates[c]['c'].item()]
        node = TensorNode(Node(lex, 1, ""))
        node.word_idx = candidates[c]['c']
        node.h = h_r.clone(), c_r.clone()
        node.logprob = candidates[c]['p']
        tree = Tree()
        tree.root = node
        tree.logprob = candidates[c]['p']
        tree.nodes.update({node.idx - 1: node})
        candidate_trees.append(tree)

    completed_sentences = []
    while True:
        new_candidates = []
        for tree_idx, tree in enumerate(candidate_trees):
            # A tree is "dead" (finished) once every leaf is an '<EOB>' node.
            dead_tree = True
            for node in tree.nodes.values():
                if node.lc is None and node.lex != '<EOB>':
                    dead_tree = False
            if dead_tree:
                new_candidates.append({
                    'p': tree.logprob,
                    'idx': tree_idx,
                    'dead_tree': True
                })
                print("done sentence: {}".format(tree.__str__()))
                continue

            new_candidates_per_tree = []
            for node in tree.nodes.values():
                # Find local_beam_size groups of children for this node via a
                # chained beam search over the left, middle and right child.
                if node.lc is not None or node.lex == '<EOB>':
                    continue

                # Left child: condition on the parent only.
                local_candidates = []
                (hs, cs), logprobs = self.beam_step(
                    state_a=(node.h[0].clone(), node.h[1].clone()),
                    x_a=node.word_idx.clone())
                ys, ix = torch.sort(logprobs, 1, True)
                for c in range(local_beam_size):
                    local_logprob = ys[0, c].item()
                    local_candidates.append({
                        'c': ix[0, c],
                        'p': local_logprob,
                        'seq_p': [local_logprob]
                    })

                # Middle child: condition on the parent and the chosen left child.
                m_local_candidates = []
                for c in range(local_beam_size):
                    (_h, _c), _logprobs = self.beam_step(
                        state_a=(node.h[0].clone(), node.h[1].clone()),
                        x_a=node.word_idx.clone(),
                        state_f=(hs.clone(), cs.clone()),
                        x_f=local_candidates[c]['c'].clone())
                    _ys, _ix = torch.sort(_logprobs, 1, True)
                    for q in range(local_beam_size):
                        local_logprob = _ys[0, q].item()
                        m_local_candidates.append({
                            'c': [local_candidates[c]['c'], _ix[0, q]],
                            'p': local_logprob + local_candidates[c]['p'],
                            'seq_p': local_candidates[c]['seq_p'] + [local_logprob],
                            'state': (_h, _c)
                        })
                m_local_candidates = sorted(m_local_candidates, key=lambda x: -x['p'])
                m_local_candidates = m_local_candidates[:local_beam_size]

                # Right child: condition on the parent and the chosen middle child.
                r_local_candidates = []
                for c in range(local_beam_size):
                    f_state = (m_local_candidates[c]['state'][0].clone(),
                               m_local_candidates[c]['state'][1].clone())
                    (_h, _c), _logprobs = self.beam_step(
                        state_a=(node.h[0].clone(), node.h[1].clone()),
                        x_a=node.word_idx.clone(),
                        state_f=f_state,
                        x_f=m_local_candidates[c]['c'][-1])
                    _ys, _ix = torch.sort(_logprobs, 1, True)
                    for q in range(local_beam_size):
                        local_logprob = _ys[0, q].item()
                        entry = {
                            'c': [_.clone() for _ in m_local_candidates[c]['c']],
                            'p': local_logprob + m_local_candidates[c]['p'],
                            'seq_p': m_local_candidates[c]['seq_p'] + [local_logprob],
                            'state': [(m_local_candidates[c]['state'][0].clone(),
                                       m_local_candidates[c]['state'][1].clone()),
                                      (_h.clone(), _c.clone())]
                        }
                        entry['c'].append(_ix[0, q].clone())
                        r_local_candidates.append(entry)
                r_local_candidates = sorted(r_local_candidates, key=lambda x: -x['p'])
                r_local_candidates = r_local_candidates[:local_beam_size]

                for candidate in r_local_candidates:
                    candidate['state'].insert(0, (hs, cs))
                    # The tree log-probability is added after the top-L
                    # combination has been selected below.
                    candidate['fid'] = node.idx
                new_candidates_per_tree.append(r_local_candidates)

            # We now have one candidate list per expandable node. Enumerating
            # every combination would cost local_beam_size ** len(lists), so the
            # combinations themselves are pruned with a beam as well.
            combination_candidates = []
            for i in range(len(new_candidates_per_tree)):
                if i == 0:
                    for j in range(local_beam_size):
                        combination_candidates.append({
                            'p': new_candidates_per_tree[i][j]['p'],
                            'seq': [j]
                        })
                    continue
                new_combination_candidates = []
                for p in range(local_beam_size):
                    for q in range(local_beam_size):
                        new_combination_candidates.append({
                            'p': new_candidates_per_tree[i][q]['p'] + combination_candidates[p]['p'],
                            'seq': combination_candidates[p]['seq'] + [q]
                        })
                new_combination_candidates = sorted(new_combination_candidates, key=lambda x: -x['p'])
                combination_candidates = new_combination_candidates[:local_beam_size]

            done_combination_candidates = []
            for i in range(local_beam_size):
                done_combination_candidates.append({
                    'p': combination_candidates[i]['p'] + tree.logprob,
                    'combination': [
                        new_candidates_per_tree[pos][wi]
                        for pos, wi in enumerate(combination_candidates[i]['seq'])
                    ],
                    'dead_tree': False,
                    'idx': tree_idx
                })
            new_candidates.extend(done_combination_candidates)

        # Keep the best global_beam_size expansions across all candidate trees.
        new_candidates = sorted(new_candidates, key=lambda x: -x['p'])
        new_candidates = new_candidates[:global_beam_size]

        new_candidate_trees = []
        for _candidate in new_candidates:
            if _candidate['dead_tree']:
                # Carry the finished tree forward unchanged.
                new_candidate_trees.append(candidate_trees[_candidate['idx']].clone())
                continue
            tree_idx = _candidate['idx']
            new_tree = candidate_trees[tree_idx].clone()
            for candidate in _candidate['combination']:
                f_node = new_tree.nodes[candidate['fid'] - 1]
                idx = len(new_tree.nodes) + 1
                lex = [vocab.idx2word[_.item()] for _ in candidate['c']]
                seq_p = candidate['seq_p']

                lc = TensorNode(Node(lex[0], idx, ""))
                lc.parent = f_node
                lc.word_idx = candidate['c'][0].clone()
                f_node.lc = lc
                lc.depth = f_node.depth + 1
                lc.h = tuple_clone(candidate['state'][0])
                lc.logprob = seq_p[0]

                mc = TensorNode(Node(lex[1], idx + 1, ""))
                mc.parent = f_node
                mc.left_brother = lc
                mc.word_idx = candidate['c'][1].clone()
                f_node.mc = mc
                mc.depth = f_node.depth + 1
                mc.h = tuple_clone(candidate['state'][1])
                lc.right_brother = mc
                mc.logprob = seq_p[1]

                rc = TensorNode(Node(lex[2], idx + 2, ""))
                rc.parent = f_node
                rc.left_brother = mc
                rc.word_idx = candidate['c'][2].clone()
                f_node.rc = rc
                rc.depth = f_node.depth + 1
                rc.h = tuple_clone(candidate['state'][2])
                mc.right_brother = rc
                rc.logprob = seq_p[2]

                new_tree.nodes.update({lc.idx - 1: lc})
                new_tree.nodes.update({mc.idx - 1: mc})
                new_tree.nodes.update({rc.idx - 1: rc})
                new_tree.logprob += candidate['p']

            # Finished trees move to completed_sentences; the rest stay on the beam.
            dead_tree = True
            for node in new_tree.nodes.values():
                if node.lc is None and node.lex != '<EOB>':
                    dead_tree = False
            if dead_tree:
                # new_tree.logprob /= len(new_tree.nodes)**3.0
                completed_sentences.append(new_tree)
            else:
                new_candidate_trees.append(new_tree)
        candidate_trees = new_candidate_trees

        # Stop once every surviving tree has reached max_seq_length nodes.
        break_flag = True
        for tree in candidate_trees:
            if len(tree.nodes) < max_seq_length:
                break_flag = False
        if break_flag:
            break
    return candidate_trees, completed_sentences
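# Usage sketch (illustrative assumption, not from the original code): beam_search
# returns both the still-growing beams and the finished trees, so a caller would
# typically keep the completed tree with the highest log-probability.
#
#     live_trees, finished = model.beam_search(fc_feat, vocab, max_seq_length=60,
#                                              global_beam_size=5, local_beam_size=3)
#     best = max(finished or live_trees, key=lambda t: t.logprob)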
def _sample(self, fc_feats, vocab, max_seq_length):
    # Greedy sampling for the multi-directional LSTM decoder: the root word is
    # predicted from the embedded image feature (ancestor input) with an empty
    # fraternal input, then the tree is grown breadth-first by complete_node()
    # until every branch ends in '<EOB>' or max_seq_length nodes are generated.
    state_a = self.init_state(1)
    xt_a = self.fc_embed(fc_feats)
    state_f = self.init_state(1)
    xt_f = self.init_input(1)
    output, state = self.md_lstm(xt_a, xt_f, state_a, state_f)
    logits = self.logit(output)
    logprobs = F.log_softmax(logits, dim=1)
    logprob, word_idx_r = logprobs.max(dim=1)
    word_idx_r = torch.LongTensor([word_idx_r])

    node_r = TensorNode(Node("", 1, ""))
    node_r.word_idx = word_idx_r
    node_r.word_embed = self.embed(word_idx_r)
    node_r.h = state
    node_r.logprob = logprob.item()

    tree = Tree()
    tree.root = node_r
    tree.nodes.update({node_r.idx - 1: node_r})

    eob_idx = vocab('<EOB>')
    queue = Queue()
    queue.put(tree.root)
    while not queue.empty() and len(tree.nodes) <= max_seq_length:
        node = queue.get()
        if node.word_idx.item() == eob_idx:
            continue
        idx = len(tree.nodes) + 1

        lc = TensorNode(Node("", idx, ""))
        lc.parent = node
        self.complete_node(lc)
        node.lc = lc
        lc.depth = node.depth + 1

        mc = TensorNode(Node("", idx + 1, ""))
        mc.parent = node
        mc.left_brother = lc
        self.complete_node(mc)
        node.mc = mc
        mc.depth = node.depth + 1
        lc.right_brother = mc

        rc = TensorNode(Node("", idx + 2, ""))
        rc.parent = node
        rc.left_brother = mc
        self.complete_node(rc)
        node.rc = rc
        rc.depth = node.depth + 1
        mc.right_brother = rc

        tree.nodes.update({lc.idx - 1: lc})
        tree.nodes.update({mc.idx - 1: mc})
        tree.nodes.update({rc.idx - 1: rc})
        queue.put(lc)
        queue.put(mc)
        queue.put(rc)

    # Fill in lexemes and display labels once decoding is complete.
    for node in tree.nodes.values():
        node.lex = vocab.idx2word[node.word_idx.item()]
        node.label = node.lex + '-' + str(node.idx)
    return tree
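# Usage sketch (illustrative, not from the original code): greedy sampling with
# the multi-directional LSTM decoder above, then flattening the tree into a
# caption string. Reading nodes in insertion (BFS) order and skipping '<EOB>'
# and '<PAD>' tokens is an assumption about the surrounding Tree class.
#
#     with torch.no_grad():
#         tree = model._sample(fc_feats, vocab, max_seq_length=60)
#     caption = ' '.join(n.lex for n in tree.nodes.values()
#                        if n.lex not in ('<EOB>', '<PAD>'))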