Example #1
0
def rerank(fm, args, network=None):
    """Pick the best-scoring parse from each k-best list and report F-scores.

    Args:
        fm: feature mapper forwarded to ``PhraseTree.load_kbests`` and to
            ``force_decoding``.
        args: namespace providing the ``gold`` and ``kbest`` paths, and the
            ``model`` path (read only when ``network`` is None).
        network: an already-loaded model; when None it is read with
            ``torch.load(args.model)``.

    Every candidate in a k-best list is scored with ``force_decoding`` and
    the highest-scoring one is kept (``np.argmax`` keeps the first on ties).
    Candidates are accessed as ``candidate['tree']``. Prints the F-score of
    the reranked picks, then the F-score of always taking the first candidate
    of each list (the baseline).

    Returns:
        FScore of the reranked picks against the gold trees.
    """
    if network is None:
        network = torch.load(args.model)
        print('Loaded model from: {}'.format(args.model))

    gold_trees = PhraseTree.load_trees(args.gold)
    kbest_lists = PhraseTree.load_kbests(args.kbest, fm)
    picked = []
    print('reranking')
    for candidates, _ in zip(kbest_lists, gold_trees):
        # Scores are held as float32, so ties are judged at that precision.
        candidate_scores = np.zeros(len(candidates), dtype='float32')
        for idx, candidate in enumerate(candidates):
            candidate_scores[idx] = force_decoding(network, candidate, fm)
        picked.append(candidates[np.argmax(candidate_scores)])

    accuracy, baseline = FScore(), FScore()
    # Baseline: the first candidate of every k-best list.
    for candidates, gold_tree in zip(kbest_lists, gold_trees):
        baseline += candidates[0]['tree'].compare(gold_tree)
    for choice, gold_tree in zip(picked, gold_trees):
        accuracy += choice['tree'].compare(gold_tree)

    print(accuracy)
    print(baseline)
    return accuracy
Example #2
0
def test(fm, args):
    """Evaluate a saved model on a test treebank and print its accuracy.

    Args:
        fm: feature mapper forwarded to ``Parser.evaluate_corpus``.
        args: namespace providing the ``test`` treebank path and the
            ``model`` path, which is loaded with ``torch.load``.
    """
    test_path = args.test
    trees = PhraseTree.load_trees(test_path)
    print('Loaded test trees from {}'.format(test_path))

    model_path = args.model
    model = torch.load(model_path)
    print('Loaded model from: {}'.format(model_path))

    result = Parser.evaluate_corpus(trees, fm, model)
    print('Accuracy: {}'.format(result))
Example #3
0
    def vocab_init(tree_file, verbose=True):
        """
        Scan a treebank file and build word, tag and label vocabularies.

        Args:
            tree_file: path handed to ``PhraseTree.load_trees``.
            verbose: when true, print per-tree progress and a size summary.

        Returns:
            dict with keys 'wdict' (word -> index), 'word_freq' (raw word
            counts, a defaultdict(int)), 'tdict' (tag -> index) and 'ldict'
            (label -> index). Word and tag indices begin with Vocab.UNK,
            Vocab.START and Vocab.STOP, followed by the observed entries in
            sorted order; label indices are the sorted labels only.
        """
        label_prefix = 'label-'
        word_freq = defaultdict(int)
        tag_freq = defaultdict(int)
        label_freq = defaultdict(int)

        for tree_no, tree in enumerate(PhraseTree.load_trees(tree_file)):
            for word, tag in tree.sentence:
                word_freq[word] += 1
                tag_freq[tag] += 1

            # Only gold actions of the form 'label-<X>' contribute labels.
            for action in Parser.gold_actions(tree):
                if action.startswith(label_prefix):
                    label_freq[action[len(label_prefix):]] += 1

            if verbose:
                print('\rTree {}'.format(tree_no), end='')
                sys.stdout.flush()

        if verbose:
            print('\r', end='')

        reserved = [Vocab.UNK, Vocab.START, Vocab.STOP]
        wdict = OrderedDict(
            (word, idx) for idx, word in enumerate(reserved + sorted(word_freq))
        )
        tdict = OrderedDict(
            (tag, idx) for idx, tag in enumerate(reserved + sorted(tag_freq))
        )
        ldict = OrderedDict(
            (lab, idx) for idx, lab in enumerate(sorted(label_freq))
        )

        if verbose:
            print('Loading features from {}'.format(tree_file))
            print('({} words, {} tags, {} nonterminal-chains)'.format(
                len(wdict), len(tdict), len(ldict)))

        return {
            'wdict': wdict,
            'word_freq': word_freq,
            'tdict': tdict,
            'ldict': ldict,
        }
Example #4
0
 def gold_data_from_file(self, fname):
     """
     Static oracle for file.

     Loads every tree in ``fname`` with ``PhraseTree.load_trees`` and returns
     a list containing ``self.gold_data(tree)`` for each tree, in file order.
     """
     return [self.gold_data(tree) for tree in PhraseTree.load_trees(fname)]
Example #5
0
 def write_raw_predicted(fname, sentences, fm, network):
     """
     Parse each sentence and write the predicted trees to ``fname``.

     Each prediction from ``Parser.parse`` is wrapped in a 'TOP' node and
     written as ``str(tree)`` on its own line. The file is opened in 'w'
     mode, so any existing content is overwritten.
     """
     # ``with`` guarantees the file is closed even if parsing raises midway.
     with open(fname, 'w') as f:
         for sentence in sentences:
             predicted = Parser.parse(sentence, fm, network)
             topped = PhraseTree(
                 symbol='TOP',
                 children=[predicted],
                 sentence=predicted.sentence,
             )
             f.write(str(topped))
             f.write('\n')
Example #6
0
 def write_predicted(fname, trees, fm, network):
     """
     Parse each tree's sentence, write the predictions and score them.

     Each input tree supplies the sentence to parse and also serves as the
     gold reference its prediction is compared against. Each prediction is
     wrapped in a 'TOP' node and written as ``str(tree)`` on its own line;
     the file is opened in 'w' mode, so existing content is overwritten.

     Returns:
         FScore accumulated over all predictions.
     """
     # ``with`` guarantees the file is closed even if parsing raises midway.
     with open(fname, 'w') as f:
         accuracy = FScore()
         for tree in trees:
             predicted = Parser.parse(tree.sentence, fm, network)
             accuracy += predicted.compare(tree)
             topped = PhraseTree(
                 symbol='TOP',
                 children=[predicted],
                 sentence=predicted.sentence,
             )
             f.write(str(topped))
             f.write('\n')
     return accuracy
Example #7
0
def train(fm, args):
    """Train a parser network and checkpoint it whenever dev F-score improves.

    Reads paths and hyper-parameters from ``args``: ``train``, ``dev``,
    ``epochs``, ``batch_size``, ``unk_param``, ``alpha``, ``beta`` and
    ``model``. Each example's loss comes from ``Parser.exploration``. The
    losses of a batch are summed and one optimizer step is taken per batch.

    The dev set is evaluated after every ``ceil(num_batches / 4)`` batches
    and after the last batch of each epoch. When the dev F-score beats the
    best so far, the whole network is saved with ``torch.save``. The file
    name is ``args.model`` with the F-score (rounded to 2 decimals) inserted
    before its last '.model' suffix.
    """
    train_data_file = args.train
    dev_data_file = args.dev
    epochs = args.epochs
    batch_size = args.batch_size
    unk_param = args.unk_param
    alpha = args.alpha
    beta = args.beta
    model_save_file = args.model

    print("this is train mode")
    start_time = time.time()

    network = Network(fm, args)
    # Move the model to the GPU *before* building the optimizer, so the
    # optimizer is constructed over the final, device-resident parameters.
    if GlobalNames.use_gpu:
        network.cuda()
    optimizer = optimize.Adadelta(network.parameters(), eps=1e-7, rho=0.99)

    training_data = fm.gold_data_from_file(train_data_file)
    num_batches = -(-len(training_data) // batch_size)  # ceiling division
    print('Loaded {} training sentences ({} batches of size {})!'.format(
        len(training_data),
        num_batches,
        batch_size,
    ))
    parse_every = -(-num_batches // 4)

    dev_trees = PhraseTree.load_trees(dev_data_file)
    print('Loaded {} validation trees!'.format(len(dev_trees)))

    best_acc = FScore()

    for epoch in range(1, epochs + 1):
        print('........... epoch {} ...........'.format(epoch))

        total_cost = 0.0
        total_states = 0
        training_acc = FScore()

        np.random.shuffle(training_data)

        for b in range(num_batches):
            network.zero_grad()
            batch = training_data[(b * batch_size):((b + 1) * batch_size)]
            batch_loss = None
            for example in batch:
                example_loss, example_states, acc = Parser.exploration(
                    example,
                    fm,
                    network,
                    alpha=alpha,
                    beta=beta,
                    unk_param=unk_param)
                total_states += example_states
                # Out-of-place add: an in-place ``+=`` would mutate the first
                # example's loss tensor, which belongs to the autograd graph.
                if batch_loss is None:
                    batch_loss = example_loss
                else:
                    batch_loss = batch_loss + example_loss
                training_acc += acc

            total_cost += _loss_to_float(batch_loss)
            batch_loss.backward()
            optimizer.step()

            # Guard against ZeroDivisionError when no states were explored.
            if total_states:
                mean_cost = total_cost / total_states
            else:
                mean_cost = float('nan')

            print(
                '\rBatch {}  Mean Cost {:.4f} [Train: {}]'.format(
                    b,
                    mean_cost,
                    training_acc,
                ),
                end='',
            )
            sys.stdout.flush()

            if ((b + 1) % parse_every) == 0 or b == (num_batches - 1):
                dev_acc = Parser.evaluate_corpus(
                    dev_trees,
                    fm,
                    network,
                )
                print(' [Dev: {}]'.format(dev_acc))

                if dev_acc > best_acc:
                    best_acc = dev_acc
                    temp_save_file = _checkpoint_path(
                        model_save_file, round(dev_acc.fscore(), 2))
                    torch.save(network, temp_save_file)
                    print(' [saved model: {}]'.format(temp_save_file))
        current_time = time.time()
        runmins = (current_time - start_time) / 60.
        print('  Elapsed time: {:.2f}m'.format(runmins))


def _loss_to_float(loss):
    """Return the single value held by a one-element loss tensor as a float."""
    # ``.data.cpu()`` returns the tensor itself when it is already on the
    # CPU. ``reshape(-1)`` handles both legacy 1-element tensors and 0-dim
    # tensors; on a 0-dim tensor, plain ``numpy()[0]`` raises IndexError.
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def _checkpoint_path(model_save_file, score):
    """Insert ``score`` before the last '.model' in ``model_save_file``.

    Only the final occurrence is rewritten, so a directory name containing
    '.model' is left intact. The path is returned unchanged when it contains
    no '.model' at all.
    """
    head, sep, tail = model_save_file.rpartition('.model')
    if not sep:
        return model_save_file
    return '{}{}.model{}'.format(head, score, tail)
Example #8
0
    def label(self, nonterminals=()):
        """
        Wrap the top stack entry in one new node per label in ``nonterminals``.

        For each label, in iteration order, the top ``(left, right, trees)``
        entry is popped and ``(left, right, [node])`` is pushed back, where
        ``node`` is a PhraseTree with that label whose children are ``trees``.
        Later labels therefore end up outermost. With no labels the stack is
        left unchanged.

        The default is an immutable empty tuple rather than a shared list.
        """
        for nt in nonterminals:
            left, right, trees = self.stack.pop()
            tree = PhraseTree(symbol=nt, children=trees)
            self.stack.append((left, right, [tree]))
Example #9
0
 def shift(self):
     """
     Move the next input word onto the stack.

     Pushes ``(i, i, [PhraseTree(leaf=i)])`` for the current input position
     ``i`` = ``self.i``, then advances ``self.i`` by one.
     """
     position = self.i
     leaf = PhraseTree(leaf=position)
     self.stack.append((position, position, [leaf]))
     self.i = position + 1