Example #1
0
 def __init__(self, model, beam_size=1):
     """Wrap a trained shift-reduce model for decoding.

     :param model: the parsing model; switched to eval mode immediately,
         since this wrapper is only used for inference.
     :param beam_size: number of hypotheses kept during beam search
         (defaults to greedy decoding with a beam of 1).
     """
     self.beam_size = beam_size
     self.transition = SRTransition()
     self.model = model
     self.model.eval()
Example #2
0
def train(cdtb):
    """Train a shift-reduce discourse parser on a CDTB-style corpus.

    Builds a model from ``cdtb.train``, runs mini-batch SGD (gradients are
    accumulated over ``batch_size`` discourses before each optimizer step),
    periodically evaluates on ``cdtb.validate``, and keeps the best-scoring
    checkpoint, finally copying it to the configured model path.

    :param cdtb: corpus object exposing ``train``, ``validate`` and ``test``
        splits (iterables of discourse objects) — assumed, confirm against caller.
    """
    model = build_model(cdtb.train)
    transition = SRTransition()

    # hyper-parameters from the shared config section
    num_epoch = config.get(config_section, "num_epoch", rtype=int)
    batch_size = config.get(config_section, "batch_size", rtype=int)
    eval_every = config.get(config_section, "eval_every", rtype=int)
    lr = config.get(config_section, "lr", rtype=float)
    l2 = config.get(config_section, "l2_penalty", rtype=float)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=l2)
    # halve the learning rate at epochs 3, 6 and 12
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6, 12], gamma=0.5)
    model.train()
    optimizer.zero_grad()

    step = 0            # number of discourses seen
    batch = 0           # number of optimizer steps taken
    batch_loss = 0.
    best_model_score = 0.
    model_dir = config.get(config_section, "model_dir")
    # 1-based epoch counter instead of mutating the loop variable
    for epoch in range(1, num_epoch + 1):
        # NOTE(review): stepping the scheduler at epoch start is the pre-1.1
        # PyTorch convention this code was written against; newer PyTorch
        # expects scheduler.step() after optimizer.step().
        scheduler.step()
        print("learning rate: %f" % scheduler.get_lr()[0])
        for discourse in np.random.permutation(cdtb.train):
            step += 1
            conf = SRConfiguration(discourse.strip())
            state = model.new_state(conf)
            session = Session(conf, transition, state=state)
            # teacher forcing: follow the oracle action sequence, scoring
            # every decision against the gold label
            scores = []
            grounds = []
            for label in sr_oracle(discourse):
                action, nuclear = label
                scores.append(model(session.state))
                grounds.append(model.label2idx[label])
                if action == SRTransition.SHIFT:
                    session(action)
                    session.state = model.shift(session.state)
                else:
                    session(action, nuclear=nuclear)
                    session.state = model.reduce(session.state)
            if not scores:
                # degenerate discourse with no oracle actions:
                # torch.stack([]) would raise, so skip it
                continue
            # build the target tensor directly as int64 instead of the
            # wasteful torch.Tensor(...).long() float round-trip
            loss = criterion(torch.stack(scores), torch.tensor(grounds, dtype=torch.long))
            loss.backward()
            batch_loss += loss.item()

            if step % batch_size == 0:
                batch += 1
                optimizer.step()
                optimizer.zero_grad()
                print("step %d, epoch: %d, batch: %d, batch loss: %.3f" % (step, epoch, batch, batch_loss / batch_size))
                batch_loss = 0.
                if batch % eval_every == 0:
                    model_score = evaluate(model, cdtb.validate)
                    evaluate(model, cdtb.test)
                    if model_score > best_model_score:
                        best_model_score = model_score
                        with open("%s.%.3f" % (model_dir, model_score), "wb+") as best_model_fd:
                            print("save new best model to %s.%.3f" % (model_dir, model_score))
                            torch.save(model, best_model_fd)
                    # evaluate() presumably switches to eval mode; restore training mode
                    model.train()
    # copy best model to model dir; guard against the case where no
    # checkpoint was ever saved (the source file would not exist)
    if best_model_score > 0.:
        shutil.copy2("%s.%.3f" % (model_dir, best_model_score), model_dir)