def test_l3_kbest():
    params, lengths = params_l3()
    dist = SentCFG(params, lengths=lengths)

    _, _, _, spans = dist.argmax
    spans, trees = extract_parses(spans, lengths)
    best_trees = "((0 1) 2)"
    best_spans = [(0, 1, 2), (0, 2, 2)]
    assert spans[0] == best_spans
    assert trees[0] == best_trees

    _, _, _, spans = dist.topk(4)
    size = (1, 0) + tuple(range(2, spans.dim()))
    spans = spans.permute(size)
    spans, trees = extract_topk(spans, lengths)
    best_trees = "((0 1) 2)"
    best_spans = [
        [(0, 1, 2), (0, 2, 2)],
        [(0, 1, 2), (0, 2, 2)],
        [(0, 1, 1), (0, 2, 2)],
        [(0, 1, 1), (0, 2, 2)],
    ]
    for i, (span, tree) in enumerate(zip(spans, trees)):
        assert span == best_spans[i]
        assert tree == best_trees
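# The helper params_l3() above is project-specific and not shown here. Below is a
# minimal, hypothetical stand-in for it, assuming torch-struct's documented SentCFG
# parameterization: terms (b x N x T), rules (b x NT x (NT+T) x (NT+T)), and
# roots (b x NT). The values are random log-potentials, so the specific parses
# asserted in the test above would not necessarily hold for this sketch.
import torch
from torch_struct import SentCFG

def params_l3_sketch(batch=1, N=3, NT=2, T=2):
    terms = torch.rand(batch, N, T).log()
    rules = torch.rand(batch, NT, NT + T, NT + T).log()
    roots = torch.rand(batch, NT).log()
    lengths = torch.full((batch,), N, dtype=torch.long)
    return (terms, rules, roots), lengths

params, lengths = params_l3_sketch()
dist = SentCFG(params, lengths=lengths)
print(dist.partition)         # log-partition of each sentence
_, _, _, spans = dist.argmax  # highest-scoring spans, as unpacked in the test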
def train():
    # model.train()
    losses = []
    for epoch in range(2):
        for i, ex in enumerate(train_iter):
            opt.zero_grad()
            words, lengths = ex.word
            N, batch = words.shape
            words = words.long()
            params = model(words.cuda().transpose(0, 1))
            dist = SentCFG(params, lengths=lengths)
            loss = dist.partition.mean()
            (-loss).backward()
            losses.append(loss.detach())
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            opt.step()
            if i > 100:
                break
            if i % 100 == 1:
                print(-torch.tensor(losses).mean(), words.shape)
                losses = []
    if args.debug:
        print(f"saving to {args.debug}...")
        torch.save(params[0], args.debug)
def train(self, niter=1):
    for _ in range(niter):
        losses = []
        self.opt.zero_grad()
        params = self.model(self.words)
        dist = SentCFG(params, lengths=self.lengths)
        loss = dist.partition.mean()
        (-loss).backward()
        losses.append(loss.detach())
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3.0)
        self.opt.step()
def forward_parser(self, captions, lengths):
    params, kl = self.parser(captions)
    dist = SentCFG(params, lengths=lengths)

    the_spans = dist.argmax[-1]
    argmax_spans, trees, lprobs = utils.extract_parses(the_spans, lengths.tolist(), inc=0)

    ll = dist.partition
    nll = -ll
    kl = torch.zeros_like(nll) if kl is None else kl
    return nll, kl, argmax_spans, trees, lprobs
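# A hypothetical training step built on forward_parser; `model`, `optimizer`,
# `captions`, and `lengths` are assumed to exist in the caller, and the actual
# objective in the surrounding code base may weight the terms differently.
nll, kl, argmax_spans, trees, lprobs = model.forward_parser(captions, lengths)
loss = (nll + kl).mean()  # ELBO-style objective: reconstruction NLL plus KL
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
optimizer.step()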
def train(self, niter=1):
    losses = []
    for i, ex in enumerate(self.train_iter):
        if i == niter:
            break
        self.opt.zero_grad()
        words, lengths = ex.word
        words = words.long()
        params = self.model(words.to(device=self.device).transpose(0, 1))
        dist = SentCFG(params, lengths=lengths)
        loss = dist.partition.mean()
        (-loss).backward()
        losses.append(loss.detach())
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3.0)
        self.opt.step()
def forward_txt_parser(self, input, lengths):
    params, kl = self.txt_parser(input)
    dist = SentCFG(params, lengths=lengths)

    the_spans = dist.argmax[-1]
    argmax_spans, trees, lprobs = utils.extract_parses(the_spans, lengths.tolist(), inc=0)

    txt_outputs = self.txt_enc(input, lengths)
    ll, span_margs = dist.inside_im
    nll = -ll
    kl = torch.zeros_like(nll) if kl is None else kl
    return txt_outputs, nll, kl, span_margs, argmax_spans, trees, lprobs
def eval_trees(args):
    checkpoint = torch.load(args.model, map_location='cpu')
    opt = checkpoint['opt']
    use_mean = True

    # load vocabulary used by the model
    data_path = args.data_path
    #data_path = getattr(opt, "data_path", args.data_path)
    vocab_name = getattr(opt, "vocab_name", args.vocab_name)
    vocab = pickle.load(open(os.path.join(data_path, vocab_name), 'rb'))
    checkpoint['word2idx'] = vocab.word2idx
    opt.vocab_size = len(vocab)

    parser = checkpoint['model']
    parser = make_model(parser, opt)
    parser.cuda()
    parser.eval()

    batch_size = 5
    prefix = args.prefix
    print('Loading dataset', data_path + prefix + args.split)
    data_loader = data.eval_data_iter(data_path, prefix + args.split, vocab, batch_size=batch_size)

    # stats
    trees = list()
    n_word, n_sent = 0, 0
    per_label_f1 = defaultdict(list)
    by_length_f1 = defaultdict(list)
    sent_f1, corpus_f1 = [], [0., 0., 0.]
    total_ll, total_kl, total_bc, total_h = 0., 0., 0., 0.

    pred_out = open(args.out_file, "w")

    for i, (captions, lengths, spans, labels, tags, ids) in enumerate(data_loader):
        lengths = torch.tensor(lengths).long() if isinstance(lengths, list) else lengths
        if torch.cuda.is_available():
            lengths = lengths.cuda()
            captions = captions.cuda()

        params, kl = parser(captions, lengths, use_mean=use_mean)
        dist = SentCFG(params, lengths=lengths)

        arg_spans = dist.argmax[-1]
        argmax_spans, _, _ = utils.extract_parses(arg_spans, lengths.tolist(), inc=0)

        candidate_trees = list()
        bsize = captions.shape[0]
        n_word += (lengths + 1).sum().item()
        n_sent += bsize

        for b in range(bsize):
            max_len = lengths[b].item()
            pred = [(a[0], a[1]) for a in argmax_spans[b] if a[0] != a[1]]
            pred_set = set(pred[:-1])
            gold = [(l, r) for l, r in spans[b] if l != r]
            gold_set = set(gold[:-1])

            ccaption = captions[b].tolist()[:max_len]
            sent = [vocab.idx2word[int(word)] for _, word in enumerate(ccaption)]
            iitem = (sent, gold, labels, pred)
            json.dump(iitem, pred_out)
            pred_out.write("\n")

            tp, fp, fn = utils.get_stats(pred_set, gold_set)
            corpus_f1[0] += tp
            corpus_f1[1] += fp
            corpus_f1[2] += fn

            overlap = pred_set.intersection(gold_set)
            prec = float(len(overlap)) / (len(pred_set) + 1e-8)
            reca = float(len(overlap)) / (len(gold_set) + 1e-8)
            if len(gold_set) == 0:
                reca = 1.
                if len(pred_set) == 0:
                    prec = 1.
            f1 = 2 * prec * reca / (prec + reca + 1e-8)
            sent_f1.append(f1)

            word_tree = build_parse(argmax_spans[b], captions[b].tolist(), vocab)
            candidate_trees.append(word_tree)

            for j, gold_span in enumerate(gold[:-1]):
                label = labels[b][j]
                label = re.split("=|-", label)[0]
                per_label_f1.setdefault(label, [0., 0.])
                per_label_f1[label][0] += 1

                lspan = gold_span[1] - gold_span[0] + 1
                by_length_f1.setdefault(lspan, [0., 0.])
                by_length_f1[lspan][0] += 1

                if gold_span in pred_set:
                    per_label_f1[label][1] += 1
                    by_length_f1[lspan][1] += 1

        appended_trees = ['' for _ in range(len(ids))]
        for j in range(len(ids)):
            tree = candidate_trees[j]
            appended_trees[ids[j] - min(ids)] = tree
        for tree in appended_trees:
            #print(tree)
            pass
        trees.extend(appended_trees)
        #if i == 50: break

    tp, fp, fn = corpus_f1
    prec = tp / (tp + fp)
    recall = tp / (tp + fn)
    corpus_f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.
    sent_f1 = np.mean(np.array(sent_f1))

    recon_ppl = np.exp(total_ll / n_word)
    ppl_elbo = np.exp((total_ll + total_kl) / n_word)
    kl = total_kl / n_sent

    info = '\nReconPPL: {:.2f}, KL: {:.4f}, PPL (Upper Bound): {:.2f}\n' + \
           'Corpus F1: {:.2f}, Sentence F1: {:.2f}'
    info = info.format(recon_ppl, kl, ppl_elbo, corpus_f1 * 100, sent_f1 * 100)
    print(info)

    f1_ids = ["CF1", "SF1", "NP", "VP", "PP", "SBAR", "ADJP", "ADVP"]
    f1s = {"CF1": corpus_f1, "SF1": sent_f1}

    print("\nPER-LABEL-F1 (label, acc)\n")
    for k, v in per_label_f1.items():
        print("{}\t{:.4f} = {}/{}".format(k, v[1] / v[0], v[1], v[0]))
        f1s[k] = v[1] / v[0]

    f1s = ['{:.2f}'.format(float(f1s[x]) * 100) for x in f1_ids]
    print("\t".join(f1_ids))
    print("\t".join(f1s))

    acc = []
    print("\nPER-LENGTH-F1 (length, acc)\n")
    xx = sorted(list(by_length_f1.items()), key=lambda x: x[0])
    for k, v in xx:
        print("{}\t{:.4f} = {}/{}".format(k, v[1] / v[0], v[1], v[0]))
        if v[0] >= 5:
            acc.append((str(k), '{:.2f}'.format(v[1] / v[0])))
    k = [x for x, _ in acc]
    v = [x for _, x in acc]
    print("\t".join(k))
    print("\t".join(v))

    pred_out.close()
    return trees
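# A hypothetical command-line entry point for eval_trees; the flag names simply
# mirror the attributes the function reads (args.model, args.data_path,
# args.vocab_name, args.prefix, args.split, args.out_file) and may not match
# the repository's real evaluation script.
if __name__ == "__main__":
    import argparse
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", required=True, help="path to the saved checkpoint")
    arg_parser.add_argument("--data_path", required=True, help="directory with data and vocab")
    arg_parser.add_argument("--vocab_name", default="vocab.pkl")
    arg_parser.add_argument("--prefix", default="")
    arg_parser.add_argument("--split", default="test")
    arg_parser.add_argument("--out_file", default="pred_parses.jsonl")
    eval_trees(arg_parser.parse_args())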