Example #1
0
    def annotate(self, discourse):
        """Parse a discourse with beam-search shift-reduce decoding.

        Starting from the initial configuration, every partial parse in the
        beam is expanded by all valid transitions scored by the model; the
        ``self.beam_size`` cheapest successors (by accumulated negative
        log-probability) survive to the next round. Finished parses are
        collected and the cheapest one wins.

        :param discourse: the discourse to annotate (passed straight to
            ``SRConfiguration``).
        :return: the discourse of the highest-ranked finished parse.
        """
        Candidate = namedtuple("Candidate", "session cost")
        config = SRConfiguration(discourse)
        init_state = self.model.new_state(config)
        beam = [Candidate(Session(config, self.transition, state=init_state), cost=0)]
        finished = []

        while beam:
            expanded = []
            for cand in beam:
                if cand.session.terminate():
                    # A completed parse leaves the beam and becomes a hypothesis.
                    finished.append(cand)
                    continue
                allowed = cand.session.valid()
                for (act, nuc), prob in self.model.score(cand.session.state).items():
                    if act not in allowed:
                        continue
                    # Each successor gets its own session copy so siblings
                    # do not share mutable parser state.
                    succ = copy(cand.session)
                    if act == SRTransition.SHIFT:
                        succ(act)
                        succ.state = self.model.shift(succ.state)
                    else:
                        succ(act, nuclear=nuc)
                        succ.state = self.model.reduce(succ.state)
                    expanded.append(
                        Candidate(session=succ,
                                  cost=cand.cost - math.log(prob)))
            # Keep only the beam_size cheapest partial parses.
            expanded.sort(key=lambda c: c.cost)
            beam = expanded[:self.beam_size]

        finished.sort(key=lambda c: c.cost)
        return finished[0].session.current.discourse
Example #2
0
File: train.py  Project: cxncu001/cdtparser
def train(cdtb):
    """Train the shift-reduce parser model on a CDTB-style corpus.

    Runs teacher-forced training over oracle transition sequences, applies
    gradient updates every ``batch_size`` discourses, periodically evaluates
    on the validation split, checkpoints the best-scoring model, and finally
    copies the best checkpoint to ``model_dir``.

    :param cdtb: corpus object; assumed to expose ``train``, ``validate`` and
        ``test`` splits of discourses — TODO confirm against caller.
    """
    model = build_model(cdtb.train)
    transition = SRTransition()

    # Hyper-parameters come from the shared config section.
    num_epoch = config.get(config_section, "num_epoch", rtype=int)
    batch_size = config.get(config_section, "batch_size", rtype=int)
    eval_every = config.get(config_section, "eval_every", rtype=int)
    lr = config.get(config_section, "lr", rtype=float)
    l2 = config.get(config_section, "l2_penalty", rtype=float)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=l2)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6, 12], gamma=0.5)
    model.train()
    optimizer.zero_grad()

    step = 0
    batch = 0
    batch_loss = 0.
    best_model_score = 0.
    best_model_saved = False  # tracks whether any checkpoint file exists
    model_dir = config.get(config_section, "model_dir")
    # epochs are 1-based (the original incremented the loop variable in place)
    for epoch in range(1, num_epoch + 1):
        # NOTE(review): stepping the scheduler at epoch start follows the
        # pre-1.1 PyTorch convention; newer releases expect it after
        # optimizer.step(). Kept as-is to preserve the original LR schedule.
        scheduler.step()
        print("learning rate: %f" % scheduler.get_lr()[0])
        for discourse in np.random.permutation(cdtb.train):
            step += 1
            conf = SRConfiguration(discourse.strip())
            state = model.new_state(conf)
            session = Session(conf, transition, state=state)
            scores = []
            grounds = []
            # Teacher forcing: follow the oracle's action sequence, recording
            # the model's score vector and the gold label index at each step.
            for label in sr_oracle(discourse):
                action, nuclear = label
                scores.append(model(session.state))
                grounds.append(model.label2idx[label])
                if action == SRTransition.SHIFT:
                    session(action)
                    session.state = model.shift(session.state)
                else:
                    session(action, nuclear=nuclear)
                    session.state = model.reduce(session.state)
            # torch.tensor with an explicit dtype avoids the float round-trip
            # of torch.Tensor(grounds).long().
            loss = criterion(torch.stack(scores),
                             torch.tensor(grounds, dtype=torch.long))
            loss.backward()
            batch_loss += loss.item()

            # Gradients accumulate across discourses; update once per batch.
            if step % batch_size == 0:
                batch += 1
                optimizer.step()
                optimizer.zero_grad()
                print("step %d, epoch: %d, batch: %d, batch loss: %.3f" % (step, epoch, batch, batch_loss / batch_size))
                batch_loss = 0.
                if batch % eval_every == 0:
                    model_score = evaluate(model, cdtb.validate)
                    evaluate(model, cdtb.test)
                    if model_score > best_model_score:
                        best_model_score = model_score
                        with open("%s.%.3f" % (model_dir, model_score), "wb+") as best_model_fd:
                            print("save new best model to %s.%.3f" % (model_dir, model_score))
                            torch.save(model, best_model_fd)
                        best_model_saved = True
                    # evaluate() presumably switches to eval mode; restore.
                    model.train()
    # copy best model to model dir — guarded: the original crashed with
    # FileNotFoundError when no checkpoint was ever written (too few
    # batches, or the validation score never improved on 0.0).
    if best_model_saved:
        shutil.copy2("%s.%.3f" % (model_dir, best_model_score), model_dir)
    else:
        print("warning: no best-model checkpoint was saved during training")